library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────────────────────────────────────────────────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.1     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(glue)
source("util.R")

# Limit the values of `x` to the closed interval [min, max], elementwise.
#
# x:   numeric vector to be limited.
# max: upper bound; every value above it is replaced by `max`.
# min: lower bound; defaults to `-max`, i.e. an interval symmetric around 0.
#
# Returns a vector as long as `x`; `NA` entries stay `NA`.
# Uses the vectorised base functions `pmax()` / `pmin()` instead of a
# three-branch `case_when()`, so the helper needs no attached packages.
clamp <- function(x, max, min = -max){
  pmin(pmax(x, min), max)
}
# Load the stored prediction results and stack them into one long tibble.
pert_res <- bind_rows(readRDS("../benchmark/output/double_perturbation_results_predictions.RDS"))
# Load the stored run parameters and flatten them to one row per
# (run, perturbation):
#  * `perturbation` holds the values of `test_train_labels`,
#  * `train` holds the corresponding names of `test_train_labels`,
#  * the nested `parameters` tibble is unpacked into top-level columns.
parameters <- readRDS(file.path("../benchmark/output/double_perturbation_results_parameters.RDS")) %>%
  map(\(p) tibble(id = p$id, name = p$name, parameters = as_tibble(p$parameters), 
                  train = names(p$test_train_labels), perturbation = p$test_train_labels)) %>%
  bind_rows() %>%
  unnest(perturbation) %>%
  unpack(parameters)
# Normalise the perturbation labels, attach the run parameters and split the
# run name into its components.
res <- pert_res %>%
  # Split each label at the first "+" or "_" into at most two parts.
  mutate(perturbation_split = str_split(perturbation, pattern = "[+_]", n = 2)) %>%
  # Canonical form: only "ctrl"/empty parts -> "ctrl"; two parts -> kept as is;
  # a single part -> padded with "ctrl".
  mutate(perturbation_split = map(perturbation_split, \(x) {
    if(all(x == "ctrl" | x == "")) "ctrl" 
    else if(length(x) == 2) x
    else c(x, "ctrl")
  })) %>%
  # Rebuild the label from the canonical parts, joined by "+".
  mutate(perturbation = map_chr(perturbation_split, paste0, collapse = "+")) %>%
  tidylog::left_join(parameters, by = c("id", "name", "perturbation")) %>%  # Matches most of x. Non matches are from scGPT and are not in training
  # Drop predictions that found no entry in `parameters`.
  tidylog::filter(! is.na(train)) %>%
  # `name` is split at "-" into dataset, seed and method.
  separate(name, sep = "-", into = c("dataset_name2", "seed2", "method"), convert = TRUE) %>%
  # NOTE(review): this keeps a row when EITHER the dataset OR the seed parsed
  # from `name` agrees with the joined parameters. If it is meant as a
  # consistency check of both, `&` may have been intended -- confirm.
  tidylog::filter(dataset_name2 == dataset_name | seed2 == seed) %>%
  dplyr::select(-c(dataset_name2, seed2)) %>%
  filter(method != "lpm")
## left_join: added 7 columns (dataset_name, test_train_config_id, seed, perturbation_type, model_type, …)
##            > rows only in x                0
##            > rows only in parameters (     0)
##            > matched rows             11,250
##            >                         ========
##            > rows total               11,250
## filter: no rows removed
## filter: no rows removed
res
## # A tibble: 11,250 × 13
##    id           method perturbation prediction prediction_std perturbation_split
##    <chr>        <chr>  <chr>        <named li> <named list>   <list>            
##  1 6248f7c56f1… scgpt  AHR+FEV      <dbl>      <NULL>         <chr [2]>         
##  2 6248f7c56f1… scgpt  AHR+KLF1     <dbl>      <NULL>         <chr [2]>         
##  3 6248f7c56f1… scgpt  AHR+ctrl     <dbl>      <NULL>         <chr [2]>         
##  4 6248f7c56f1… scgpt  ARID1A+ctrl  <dbl>      <NULL>         <chr [2]>         
##  5 6248f7c56f1… scgpt  ARRDC3+ctrl  <dbl>      <NULL>         <chr [2]>         
##  6 6248f7c56f1… scgpt  ATL1+ctrl    <dbl>      <NULL>         <chr [2]>         
##  7 6248f7c56f1… scgpt  BAK1+ctrl    <dbl>      <NULL>         <chr [2]>         
##  8 6248f7c56f1… scgpt  BCL2L11+BAK1 <dbl>      <NULL>         <chr [2]>         
##  9 6248f7c56f1… scgpt  BCL2L11+TGF… <dbl>      <NULL>         <chr [2]>         
## 10 6248f7c56f1… scgpt  BCL2L11+ctrl <dbl>      <NULL>         <chr [2]>         
## # ℹ 11,240 more rows
## # ℹ 7 more variables: dataset_name <chr>, test_train_config_id <chr>,
## #   seed <int>, perturbation_type <chr>, model_type <chr>, epochs <dbl>,
## #   train <chr>
# Sanity check: for the ground truth of seed 1, count how many conditions
# perturb 0, 1 or 2 genes (every entry other than "ctrl" counts as one gene).
res %>%
  filter(method == "ground_truth" & seed == 1) %>%
  mutate(n_pert = lengths(map(perturbation_split, \(x) setdiff(x, "ctrl")))) %>%
  dplyr::count(dataset_name, n_pert) 
## # A tibble: 3 × 3
##   dataset_name             n_pert     n
##   <chr>                     <int> <int>
## 1 norman_from_scfoundation      0     1
## 2 norman_from_scfoundation      1   100
## 3 norman_from_scfoundation      2   124
# Reshape a long data frame into a matrix.
#
# x:      data frame in long format.
# rows:   column (unquoted) whose values become the row names.
# cols:   column (unquoted) whose values become the column names.
# values: column (unquoted) that fills the cells.
# ...:    forwarded to `pivot_wider()` (e.g. `values_fn`, `values_fill`).
#
# Returns a matrix with one row per distinct value of `rows`.
long2matrix <- function(x, rows, cols, values, ...){
  # Keep only the three relevant columns, then spread `cols` across the width.
  wide <- pivot_wider(
    transmute(x, {{rows}}, {{cols}}, {{values}}),
    id_cols = {{rows}}, names_from = {{cols}}, values_from = {{values}}, ...
  )
  # The id column comes first in the pivoted table; it turns into the row names.
  out <- as.matrix(wide[, -1])
  rownames(out) <- wide[[1]]
  out
}

# Heatmap of which (method, perturbation) pairs have a prediction for seed 1.
# A prediction counts as present when its first element is not NA.
res |>
  filter(seed == 1) |>
  mutate(present = map_lgl(prediction, \(x) ! is.na(x[1]))) |>
  (\(data){
    # Method x perturbation matrix; `values_fn` turns the logical flag into 0/1.
    mat <- long2matrix(data, rows = method, cols = perturbation, values = present, values_fn = \(x) x * 1.0) 
    # Pairs that do not occur in `data` come out of the pivot as NA: treat as absent.
    mat[is.na(mat)] <- 0
    ComplexHeatmap::pheatmap(mat, main = "Valid perturbations", breaks = c(0,1), color = c("lightgrey", "darkred"),
                             show_row_dend = FALSE, show_column_dend = FALSE, show_colnames = FALSE, legend = FALSE)
  })()

# Ground-truth expression profile of the control condition per
# (dataset, seed); used as the reference ("baseline") further below.
baselines <- res %>%
  filter(method == "ground_truth" & perturbation == "ctrl") %>%
  dplyr::select(baseline = prediction, dataset_name, seed)
# Append a synthetic "no_change" method: for every distinct perturbation its
# prediction is simply the control profile of the same dataset and seed.
res <- bind_rows(res, res %>%
  distinct(perturbation, perturbation_split, dataset_name, test_train_config_id, seed, train) %>%
  inner_join(baselines %>% dplyr::rename(prediction = baseline), by = c("dataset_name", "seed")) %>%
  mutate(method = "no_change"))
# Rank of every gene by its ground-truth expression in the control condition,
# computed separately per (dataset, seed). Rank 1 is the most highly
# expressed gene; ties are broken by order of appearance.
# `ties.method` is spelled out in full instead of relying on partial
# matching of the abbreviated argument name `ties`.
expr_rank_df <- res %>%
  filter(method == "ground_truth" & perturbation == "ctrl") %>%
  dplyr::select(dataset_name, seed, observed = prediction) %>%
  mutate(gene_name = map(observed, names)) %>%
  unnest(c(gene_name, observed)) %>%
  mutate(expr_rank = rank(desc(observed), ties.method = "first"), .by = c(seed, dataset_name)) %>%
  dplyr::select(dataset_name, seed, gene_name, expr_rank)

# Rank of every gene by the absolute difference between its ground-truth
# expression under a perturbation and under control, computed separately per
# (dataset, seed, perturbation). Rank 1 is the most strongly changed gene.
de_rank_df <- res %>%
  filter(method == "ground_truth") %>% 
  dplyr::select(dataset_name, seed, perturbation, observed = prediction) %>%
  mutate(gene_name = map(observed, names)) %>%
  unnest(c(gene_name, observed)) %>%
  # Attach the control expression of the same gene.
  left_join(baselines |> mutate(gene_name = map(baseline, names)) |> unnest(c(gene_name, baseline)), by = c("dataset_name", "seed", "gene_name")) %>%
  mutate(de = abs(observed - baseline)) %>%
  mutate(de_rank = rank(desc(de), ties.method = "first"), .by = c(seed, dataset_name, perturbation)) %>%
  dplyr::select(dataset_name, seed, perturbation, gene_name, de_rank)
# Remove R's limit on the vector heap size before the large joins/unnests below.
mem.maxVSize(vsize = Inf)
## [1] Inf
# Pair every prediction with the ground-truth profile (`observed`) of the
# same dataset, seed and perturbation.
contr_res <- tidylog::full_join(filter(res, method != "ground_truth"),
                                filter(res, method == "ground_truth") %>% 
                                  dplyr::select(dataset_name, seed, perturbation, observed = prediction),
           by = c("dataset_name", "seed", "perturbation"))
## full_join: added one column (observed)
##            > rows only in filter(res, method != "..       0
##            > rows only in filter(res, method == "..       0
##            > matched rows                            11,250
##            >                                        ========
##            > rows total                              11,250
# Per (dataset, seed, method, perturbation, train label): compare prediction
# and observation on the 1000 most highly expressed genes.
#  * r2:       correlation from `cor()` between prediction and observation
#              (the correlation itself, not its square),
#  * r2_delta: the same correlation after subtracting the control profile
#              from both sides,
#  * l2:       Euclidean distance between prediction and observation.
# `cor()` warns and returns NA when one side has zero variance, e.g. when
# `prediction - baseline` is constant.
res_metrics <- contr_res %>%
  tidylog::left_join(baselines, by = c("dataset_name", "seed")) %>%
  dplyr::select(-c(id, test_train_config_id)) %>%
  mutate(gene_name = map(prediction, names)) %>%
  # One row per gene; the three list columns are unnested in parallel.
  unnest(c(gene_name, prediction, observed, baseline)) %>%
  inner_join(expr_rank_df %>% dplyr::select(dataset_name, seed, gene_name, expr_rank) %>% filter(expr_rank <= 1000), by = c("dataset_name", "seed", "gene_name")) %>%
  summarize(r2 = cor(prediction, observed),
         r2_delta = cor(prediction - baseline, observed - baseline),
         l2 =sqrt(sum((prediction - observed)^2)),
         .by = c(dataset_name, seed, method, perturbation, train))
## left_join: added one column (baseline)
##            > rows only in x               0
##            > rows only in baselines (     0)
##            > matched rows            11,250
##            >                        ========
##            > rows total              11,250
## Warning: There were 1170 warnings in `summarize()`.
## The first warning was:
## ℹ In argument: `r2_delta = cor(prediction - baseline, observed - baseline)`.
## ℹ In group 225: `dataset_name = "norman_from_scfoundation"`, `seed = 1`, `method = "scgpt"`, `perturbation = "ctrl"`, `train = "train"`.
## Caused by warning in `cor()`:
## ! the standard deviation is zero
## ℹ Run `dplyr::last_dplyr_warnings()` to see the 1169 remaining warnings.
res_metrics
## # A tibble: 11,250 × 8
##    dataset_name              seed method perturbation train    r2 r2_delta    l2
##    <chr>                    <int> <chr>  <chr>        <chr> <dbl>    <dbl> <dbl>
##  1 norman_from_scfoundation     1 scgpt  AHR+FEV      test  0.968    0.398  7.40
##  2 norman_from_scfoundation     1 scgpt  AHR+KLF1     val   0.989    0.118  4.76
##  3 norman_from_scfoundation     1 scgpt  AHR+ctrl     train 0.996    0.570  2.72
##  4 norman_from_scfoundation     1 scgpt  ARID1A+ctrl  train 0.990    0.694  4.82
##  5 norman_from_scfoundation     1 scgpt  ARRDC3+ctrl  train 0.999    0.644  1.57
##  6 norman_from_scfoundation     1 scgpt  ATL1+ctrl    train 0.993    0.751  3.88
##  7 norman_from_scfoundation     1 scgpt  BAK1+ctrl    train 0.997   -0.217  2.88
##  8 norman_from_scfoundation     1 scgpt  BCL2L11+BAK1 train 0.996   -0.138  3.10
##  9 norman_from_scfoundation     1 scgpt  BCL2L11+TGF… val   0.997    0.216  2.65
## 10 norman_from_scfoundation     1 scgpt  BCL2L11+ctrl train 0.998    0.101  2.47
## # ℹ 11,240 more rows
# Display names for methods, datasets and method groups. The order of
# `method_labels` also defines the factor level order used in the plots.
method_labels <- c("no_change" = "No Change", "additive_model" = "Additive",
                   "scgpt" = "scGPT", "scfoundation" = "scFoundation",
                   "uce" = "UCE*", "scbert" = "scBERT*", "geneformer" = "Geneformer*",
                   "gears" = "GEARS", "cpa" = "CPA")
dataset_labels <- c("norman_from_scfoundation" = "Norman")
approach_labels <- c("baseline" = "Baselines", "foundation_model" = "Foundation Models", "deep_learning" = "Other Deep Learning Models")

# Lookup table that assigns every method in `method_labels` to one approach.
# "uce33" is listed below but is not a key of `method_labels`, so it never
# occurs in this table. A method hitting the "Forgot" default would become
# NA in the final factor because "Forgot" is not one of its levels.
approach_annot <- tibble(method = names(method_labels)) |>
  mutate(approach = case_when(
    method %in% c("no_change", "additive_model") ~ "baseline",
    method %in% c("scgpt", "scfoundation", "uce", "uce33", "scbert", "geneformer") ~ "foundation_model",
    method %in% c("gears", "cpa") ~ "deep_learning",
    .default = "Forgot"
  )) |>
  mutate(method = factor(method, levels = names(method_labels))) %>%
  mutate(approach = factor(approach, levels = names(approach_labels))) 

# Metrics of the perturbation that is highlighted in the plots (seed 1 only).
sel_perts <- res_metrics %>%
  filter(seed == 1) %>%
  filter(perturbation %in% c("CEBPE+CEBPB"))

# Plot data: held-out perturbations ("test" and "val") of the labelled methods.
main_pl_data <- res_metrics %>%
  filter(train %in% c("test", "val")) %>%
  filter(method %in% names(method_labels)) %>%
  mutate(method = factor(method, levels = names(method_labels))) %>%
  mutate(dataset_name = factor(dataset_name, levels = names(dataset_labels))) %>%
  left_join(approach_annot, by = "method") %>%
  # Axis label of the form "<method>|<approach>"; the nested axis guide of the
  # plots splits it at "|".
  mutate(label = paste0(method_labels[as.character(method)], "|", approach_labels[as.character(approach)])) |>
  # Order the labels by approach first and by method within each approach.
  mutate(label = fct_reorder(label, as.integer(approach) * 1000 + as.integer(method) )) %>%
  # Flag the rows that belong to the highlighted perturbation.
  left_join(sel_perts %>% distinct(seed, method, perturbation) %>% mutate(highlight = TRUE), by = c("seed", "method", "perturbation")) %>%
  replace_na(list(highlight = FALSE)) 

# Beeswarm plot of the delta correlation (`r2_delta`) per method, with one
# facet per approach.
main_pl_double_pearson <- main_pl_data %>%
  ggplot(aes(x = label, y = r2_delta)) +
    # Reference lines at 0 and 1.
    geom_hline(yintercept = c(0, 1), color = "black", linewidth = 0.2) +
    ggbeeswarm::geom_quasirandom(aes(color = highlight, size = highlight)) +
    # Dashed grey line: the highest per-method mean of `r2_delta`.
    geom_hline(data = . %>% summarize(r2_delta_best = mean(r2_delta), .by = method) %>% slice_max(r2_delta_best, n = 1, with_ties = FALSE),
                 aes(yintercept = r2_delta_best), color = "grey", linetype = "dashed") +
    # Red bar: mean of `r2_delta` for each method.
    ungeviz::geom_hpline(data = . %>% summarize(r2_delta = mean(r2_delta), .by = c(approach, label)), aes(y = r2_delta), 
                         color = "red", width = 0.6, linewidth = 0.6) +
    ggforce::facet_row(vars(approach), scales = "free_x", space = "free", labeller = as_labeller(approach_labels)) +
    scale_color_manual(values = c("TRUE" = "orange", "FALSE" = alpha("#444444", 0.6))) +
    scale_size_manual(values = c("TRUE" = 0.6, "FALSE" = 0.1)) +
    scale_x_discrete(expand = expansion(add = 0.9)) +
    # Values outside these limits (and NA values) are dropped with a warning.
    scale_y_continuous(limits = c(-0.25, 1), expand = expansion(add = 0)) +
    labs(y = "Pearson delta") +
    # Nested x axis: the part of `label` after "|" becomes the outer level.
    guides(x = legendry::guide_axis_nested(key = legendry::key_range_auto(sep = "\\|")),
           color = "none", size = "none") +
    theme(axis.title.x = element_blank(),
          panel.grid.major.y = element_line(color = "lightgrey", linewidth = 0.1),
          panel.grid.minor.y = element_line(color = "lightgrey", linewidth = 0.1),
          strip.background = element_blank(), strip.text = element_blank(),
          panel.spacing.x = unit(0, "pt"))

# Beeswarm plot of the L2 prediction error per method, with one facet per
# approach and an arrow annotation naming the highlighted perturbation.
main_pl_double_l2 <- main_pl_data %>%
  ggplot(aes(x = label, y = l2)) +
    geom_hline(yintercept = 0, color = "black", linewidth = 0.2) +
    ggbeeswarm::geom_quasirandom(aes(color = highlight, size = highlight)) +
    # Dashed grey line: the lowest per-method mean of `l2`.
    geom_hline(data = . %>% summarize(l2_best = mean(l2), .by = method) %>% slice_min(l2_best, n = 1, with_ties = FALSE),
                 aes(yintercept = l2_best), color = "grey", linetype = "dashed") +
    # Red bar: mean of `l2` for each method.
    ungeviz::geom_hpline(data = . %>% summarize(l2 = mean(l2), .by = c(approach, label)), aes(y = l2), 
                         color = "red", width = 0.6, linewidth = 0.6) +
    # Curved arrow and text drawn in the "baseline" facet at hard-coded
    # coordinates. `font_size_tiny` is not defined in this file; presumably it
    # comes from util.R -- TODO confirm.
    ggbezier::geom_bezier(data = tibble(approach = factor("baseline",  levels(main_pl_data$approach)), x = c(1, 1.5), y = c(8.5, 10)),
                          aes(x = x, y = y, angle = c(0, 90)), arrow = grid::arrow(type = "closed", ends = "first", length = unit(1, "mm"))) +
    geom_text(data = tibble(x = 1.5, y = 10.1, text = "CEBPE+CEBPB", approach = factor("baseline",  levels(main_pl_data$approach))),
                            aes(x=x, y=y, label=text), hjust = 0.2, vjust = 0, size = font_size_tiny / .pt) +
    ggforce::facet_row(vars(approach), scales = "free_x", space = "free", labeller = as_labeller(approach_labels)) +
    scale_x_discrete(expand = expansion(add = 0.7)) +
    # Values above 12.5 are dropped with a warning.
    scale_y_continuous(limits = c(0, 12.5), expand = expansion(add = c(0, 0.5))) +
    scale_color_manual(values = c("TRUE" = "orange", "FALSE" = alpha("#444444", 0.6))) +
    scale_size_manual(values = c("TRUE" = 0.6, "FALSE" = 0.1)) +
    # Nested x axis: the part of `label` after "|" becomes the outer level.
    guides(x = legendry::guide_axis_nested(key = legendry::key_range_auto(sep = "\\|")), 
           color = "none", size = "none") +
    labs(y = "Prediction error ($L_2$)") +
    theme(axis.title.x = element_blank(),
          panel.grid.major.y = element_line(color = "lightgrey", linewidth = 0.1),
          panel.grid.minor.y = element_line(color = "lightgrey", linewidth = 0.1),
          strip.background = element_blank(), strip.text = element_blank(),
          panel.spacing.x = unit(0, "pt"))

main_pl_double_pearson
## Warning: Removed 326 rows containing missing values or values outside the scale
## range (`position_quasirandom()`).
## Warning: Removed 1 row containing missing values or values outside the scale
## range (`geom_hpline()`).

main_pl_double_l2
## Warning: Removed 147 rows containing missing values or values outside the scale
## range (`position_quasirandom()`).

# Gene-level data for the highlighted perturbation: observed and predicted
# change relative to control for the 1000 most highly expressed genes.
obs_pred_corr_dat <- contr_res %>%
  # Keep only the rows of the highlighted perturbation; this also brings in
  # the metric columns of `sel_perts` (r2, r2_delta, l2).
  inner_join(sel_perts, by = c("dataset_name", "seed", "method", "perturbation")) %>%
  filter(method %in% names(method_labels)) %>%
  mutate(method = factor(method, levels = names(method_labels))) %>%
  left_join(approach_annot, by = "method") %>%
  mutate(perturbation = fct_reorder(perturbation, -l2)) %>%
  tidylog::left_join(baselines, by = c("dataset_name", "seed")) %>%
  dplyr::select(-c(id, test_train_config_id)) %>%
  mutate(gene_name = map(prediction, names)) %>%
  unnest(c(gene_name, prediction, observed, baseline)) %>%
  inner_join(expr_rank_df %>% dplyr::select(dataset_name, seed, gene_name, expr_rank) %>% 
               filter(expr_rank <= 1000), by = c("dataset_name", "seed", "gene_name")) %>%
  # Clamp both differences to [-1, 1] so that they fit the fixed plot range.
  mutate(obs_minus_baseline = clamp(observed - baseline, min = -1, max = 1),
         pred_minus_baseline = clamp(prediction - baseline, min = -1, max = 1)) 
## left_join: added one column (baseline)
##            > rows only in x          0
##            > rows only in baselines (4)
##            > matched rows            9
##            >                        ===
##            > rows total              9
# Scatter plot of predicted vs. observed change relative to control, with
# one panel per method and the metrics printed in the top-left corner.
obs_pred_corr_pl <- obs_pred_corr_dat|>
  ggplot(aes(x = obs_minus_baseline, y = pred_minus_baseline)) +
    # Identity line.
    geom_abline(linewidth = 0.2, linetype = "dashed") +
    ggrastr::rasterize(geom_point(size = 0.5, stroke = 0), dpi = 600) +
    # Semi-transparent white box behind the metric labels.
    annotate("rect", xmin = -0.95, ymin = 0.5, xmax = 0.3, ymax = Inf, fill = "white", alpha = 0.8) +
    geom_text(data = . %>% summarize(l2 = first(l2), .by = c(method, perturbation, approach)), aes(label = paste0("$L_2$: ", round(l2, 1))),
              x = -0.95, y = Inf, hjust = 0, vjust = 1.2, size = font_size_tiny / .pt) +
    # NOTE(review): the label reads "$R^2$" but the value shown is `r2_delta`,
    # which `res_metrics` computes with `cor()` (not squared) -- confirm the
    # intended label.
    geom_text(data = . %>% summarize(r2_delta = first(r2_delta), .by = c(method, perturbation, approach)), 
              aes(label = paste0("$R^2$: ", round(r2_delta, 2))),
              x = -0.95, y = Inf, hjust = 0, vjust = 2.5, size = font_size_tiny / .pt) +
    coord_fixed(xlim = c(-1, 1), ylim = c(-1, 1)) +
    scale_x_continuous(breaks = c(-1, 0, 1)) +
    scale_y_continuous(breaks = c(-1, 0, 1)) +
    # ggh4x::facet_nested_wrap(vars(approach, method), nest_line = TRUE, 
    ggh4x::facet_wrap2(vars(method),
               nrow  = 3, # ncol = 2, 
                             labeller = labeller(approach = as_labeller(approach_labels), method = as_labeller(method_labels)),
                             strip = ggh4x::strip_nested(clip = "off")) +
    labs(x = "Observed minus control expression", y = "Predicted minus control expression")

obs_pred_corr_pl

# Ranks at which the cumulative error is kept for plotting: every rank up to
# 100, every 10th from 101 to 1000 and every 100th from 1001 to 19264.
# NOTE(review): 19264 is hard-coded; presumably the number of genes -- confirm.
sel_ranks <- c(seq(1, 100, by = 1), seq(101, 1000, by = 10), seq(1001, 19264, by = 100))

# For correlation, I could use the TTR::runCor function, but it is slow
# Gene-level predictions, observations and control values for all
# perturbations that are not labelled "train".
strat_data_init <- contr_res %>%
  filter(train != "train") %>%
  tidylog::left_join(baselines, by = c("dataset_name", "seed")) %>%
  dplyr::select(-c(id, test_train_config_id, prediction_std, epochs)) %>%
  mutate(gene_name = map(prediction, names)) %>%
  unnest(c(gene_name, prediction, observed, baseline))
## left_join: added one column (baseline)
##            > rows only in x              0
##            > rows only in baselines (    0)
##            > matched rows            3,100
##            >                        =======
##            > rows total              3,100
# Cumulative L2 error over the top-n genes, with genes ordered by their
# expression rank. `dist` at rank n is the Euclidean distance between
# prediction and observation restricted to the n highest-ranked genes.
strat_data_expr_rank <- strat_data_init %>%
  inner_join(expr_rank_df %>% dplyr::select(dataset_name, seed, gene_name, rank = expr_rank),
            by = c("dataset_name", "seed", "gene_name")) %>%
  # The cumulative sum below relies on the rows being sorted by rank.
  arrange(rank) %>%
  mutate(dist = sqrt(cumsum((prediction - observed)^2)),
         .by = c(dataset_name, seed, method, perturbation)) %>% 
  filter(rank %in% sel_ranks)


# Same computation with genes ordered by their differential-expression rank,
# which is specific to each perturbation. Rows without a matching rank get NA,
# sort last and are removed by the final filter.
strat_data_de_rank <- strat_data_init %>%
  left_join(de_rank_df %>% dplyr::select(dataset_name, seed, perturbation, gene_name, rank = de_rank), 
            by = c("dataset_name", "seed", "gene_name", "perturbation")) %>%
  arrange(rank) %>%
  mutate(dist = sqrt(cumsum((prediction - observed)^2)),
         .by = c(dataset_name, seed, method, perturbation))%>% 
  filter(rank %in% sel_ranks)
# Combine both gene orderings and summarise the cumulative L2 error over all
# rows (perturbations and seeds) that share method, dataset, rank and ordering.
strat_merged <- bind_rows(
  strat_data_expr_rank %>% mutate(sorted_by = "expr"),
  strat_data_de_rank %>% mutate(sorted_by = "de")
) %>%
  mutate(sorted_by = factor(sorted_by, levels = c("expr", "de"))) %>%
  summarize(dist_mean = mean(dist),
            # Standard error of the mean: divide by the square root of the
            # number of observations in the group. `rank` is a grouping
            # variable (the number of genes considered), not a sample size.
            dist_se = sd(dist) / sqrt(n()),
            .by = c(method, dataset_name, rank, sorted_by)) 

# One qualitative colour per labelled method, named by method key.
ggplot_colors_five <- colorspace::qualitative_hcl(length(method_labels), h = c(0, 270), c = 60, l = 70)
names(ggplot_colors_five) <- names(method_labels)

# Line plot of the mean cumulative L2 error against the number of top genes,
# with one facet per gene ordering and the method names written at the line ends.
strat_pl <- strat_merged %>%
  filter(method %in% names(method_labels)) %>%
  mutate(method = factor(method, levels = names(method_labels))) %>%
  # Manual vertical nudges so that the end-of-line labels do not overlap.
  mutate(custom_vjust = case_when(
    method == "gears" ~ -0.3,
    method == "scgpt" ~ 0.1,
    method == "uce" ~ 1.2,
    method == "scbert" ~ 0.1,
    .default = 0.5
  )) |>
  ggplot(aes(x = rank, y = dist_mean)) +
    ggrastr::rasterize(geom_line(aes(color = method), show.legend=FALSE), dpi = 600) +
    # Labels at the largest rank. `method` is a factor, so the label vector is
    # indexed by name via `as.character()`; indexing by the factor itself would
    # use the integer codes and only work while the level order matches the
    # order of `method_labels`.
    geom_text(data = . %>% filter(method != "cpa") %>% filter(rank == max(rank)), 
              aes(label = method_labels[as.character(method)], color = stage(method, after_scale = colorspace::darken(color, 0.5)), 
                  vjust = custom_vjust),
              hjust = 0, size = font_size_small / .pt, show.legend = FALSE) +
    # The CPA label is placed at y = 14 (the upper plot limit), at the point
    # per facet whose mean error exceeds 14 by the smallest amount.
    geom_text(data = . %>% filter(method == "cpa") %>% slice_min(ifelse(dist_mean > 14, dist_mean - 14, Inf), by = sorted_by), 
              aes(label = method_labels[as.character(method)], color = stage(method, after_scale = colorspace::darken(color, 0.5))),
              y = 14, vjust = 1, hjust = -0.2, size = font_size_small / .pt, show.legend = FALSE) +
    # Mark rank 1000 in the expression-sorted facet only.
    geom_vline(data = tibble(rank = 1000, sorted_by = factor("expr", levels = c("expr", "de"))), aes(xintercept = rank),
               linewidth = 0.4, linetype = "dashed", color = "grey") +
    scale_x_log10(labels = scales::label_comma(), limits = c(1, NA), expand = expansion(mult = c(0, 0.1))) +
    # scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(0, 0)), breaks = c(0, 1, 2, 5, 10, 20), transform = scales::asinh_trans()) +
    scale_y_continuous(expand = expansion(add = 0)) +
    scale_color_manual(values = ggplot_colors_five) +
    facet_wrap(vars(sorted_by), scales = "free_y",
               labeller = as_labeller(c("expr" = "genes sorted by expression", "de" = "genes sorted by differential expression"))) +
    labs(x = "top $n$ genes (log-scale)", y = "Prediction error ($L_2$)") +
    # Clipping is off so the end-of-line labels may extend beyond the panel.
    coord_cartesian(clip = "off", ylim = c(0, 14)) +
    theme(panel.spacing.x = unit(14, "mm"))

strat_pl

# Assemble the supplementary figure (panel A: delta correlation, panel B:
# stratified error) and write it to a PDF of 170 mm x 135 mm.
# `plot_assemble`, `add_text`, `add_plot` and `font_size` are not defined in
# this file; presumably they come from util.R -- TODO confirm.
plot_assemble(
  add_text("(A) Double perturbation prediction correlation", 
           x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(main_pl_double_pearson, x = 3, y = 4, width = 130, height = 47.5),
  
  add_text("(B) Prediction error stratified by the considered gene sets",
           x = 2.7, y = 53, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(strat_pl, x = 3, y = 55, width = 100, height = 80),
  
  width = 170, height = 135, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl-pearson_delta_performance.pdf"
)
## Using TikZ metrics dictionary at:
##  double_perturbation_analysis-tikzDictionary
## gg[gg1]
## Warning: Removed 326 rows containing missing values or values outside the scale
## range (`position_quasirandom()`).
## Warning: Removed 1 row containing missing values or values outside the scale
## range (`geom_hpline()`).
## gg[gg2]
## gg[gg3]
## gg[gg4]
## [1] TRUE TRUE
# Export the data behind both panels as one Excel file with a sheet per panel.
writexl::write_xlsx(list(
  "Panel A"  = main_pl_data,
  "Panel B" = strat_merged
  ), path = "../source_data/suppl-pearson_delta_performance.xlsx"
)
# For every perturbation label without "ctrl" (the double perturbations),
# list the four conditions needed to assess additivity: the pair itself
# ("AB"), each gene alone ("A", "B") and the control ("ctrl"). Every
# combination is sorted so that it can be matched independent of gene order.
all_combs <- tibble(perturbation = res$perturbation |> discard(\(x) str_detect(x, "ctrl")) |> unique()) %>%
  mutate(split = str_split(perturbation, "\\+")) %>%
  mutate(combs = map(split, \(x) list(x, c(x[1], "ctrl"), c(x[2], "ctrl"), "ctrl")),
         labels = map(split, \(x) c("AB", "A", "B", "ctrl"))) %>%
  transmute(pert_group = perturbation, 
            combs = map(combs, \(x) map(x, sort, method = "radix")),
            labels) %>%
  unnest(c(combs, labels))

# Ground-truth expression per gene for AB, A, B and ctrl side by side, plus
# the deviation of AB from the additive expectation A + B - ctrl.
ground_truth_df <- res %>%
  filter(method == "ground_truth") %>%
  # Sort with the same method as in `all_combs` so that the list-valued join
  # keys compare equal.
  mutate(perturbation_split = map(perturbation_split, sort, method = "radix")) %>%
  dplyr::select(perturbation, perturbation_split, seed, train, ground_truth = prediction) %>%
  # A single-gene or control condition can belong to several pairs, hence
  # many-to-many.
  inner_join(all_combs, by = c("perturbation_split" = "combs"), relationship = "many-to-many") %>% 
  # `unnest_named_lists` is not defined in this file; presumably it comes from
  # util.R and yields one row per gene -- TODO confirm.
  unnest_named_lists(ground_truth, names_to = "gene_name") %>%
  pivot_wider(id_cols = c(gene_name, pert_group, seed), names_from = labels, values_from = ground_truth) %>%
  mutate(error = `AB` - (A + B - ctrl))
# Restrict to seed 1 and to the 1000 most highly expressed genes.
filter_gt_df <- ground_truth_df %>%
  filter(seed == 1) |>
  inner_join(expr_rank_df %>% dplyr::select(gene_name, seed, rank = expr_rank) %>% filter(rank <= 1000) , by = c("gene_name", "seed")) 
# Fit a local false discovery rate model to the deviations from additivity.
# NOTE(review): `nulltype = 1` presumably selects the maximum-likelihood
# empirical null (the "mlest" row read further below) -- confirm against the
# locfdr documentation.
set.seed(1)
locfdr_est <- locfdr::locfdr(filter_gt_df$error, nulltype = 1)
## Warning: glm.fit: fitted rates numerically 0 occurred
## Warning in locfdr::locfdr(filter_gt_df$error, nulltype = 1): f(z) misfit =
## 76.7.  Rerun with increased df

locfdr_est$z.2
## [1] -0.1597242  0.1591445
locfdr_est$fp0
##               delta        sigma          p0
## thest  0.0000000000 1.0000000000 8.826975628
## theSD  0.0000000000 0.0000000000 0.015472160
## mlest  0.0069048519 0.0588595687 0.841291742
## mleSD  0.0003337505 0.0005341399 0.004987568
## cmest -0.0005052317 0.0839470793 0.988610934
## cmeSD  0.0002836887 0.0001713179 0.000869025
# Parameters of the fitted null distribution, taken from the "mlest" row:
# centre, spread and estimated proportion of null cases.
mean_est <- locfdr_est$fp0["mlest","delta"]
sd_est <- locfdr_est$fp0["mlest","sigma"]
p0_est <- locfdr_est$fp0["mlest", "p0"]

# Positive cut-off: the positive deviation whose local fdr is closest to 0.05.
upper_thres <- tibble(deviation = filter_gt_df$error, fdr = locfdr_est$fdr) %>%
  filter(deviation > 0) %>%
  slice_min(abs(fdr - 0.05), with_ties = FALSE) %>%
  pull(deviation)

# Negative cut-off: the negative deviation whose local fdr is closest to 0.05.
lower_thres <- tibble(deviation = filter_gt_df$error, fdr = locfdr_est$fdr) %>%
  filter(deviation < 0) %>%
  slice_min(abs(fdr - 0.05), with_ties = FALSE) %>%
  pull(deviation) 

upper_thres
##           
## 0.2017037
lower_thres
##            
## -0.2138857
# Draw tick marks along an arbitrary straight line in a ggplot.
#
# origin: point (x, y) on the line that corresponds to position 0.
# dir:    direction vector (dx, dy) of the line.
# at:     positions along the line (in multiples of `dir`) that get a tick.
# length: total length of each tick; it is drawn orthogonal to `dir` and
#         centred on the line.
# ...:    forwarded to `geom_segment()`.
#
# Returns a `geom_segment()` layer with its own data.
annotate_ticks <- function(origin = c(0,0), dir = c(1,0), at = seq(-10, 10), length = 0.1, ...){
  orth_dir <- c(dir[2], -dir[1])
  # 2 x length(at) matrix with one tick centre per column. Base `vapply()`
  # replaces the unexported helper `lemur:::mply_dbl()`, so the function no
  # longer depends on the internals of another package.
  pos <- vapply(at, \(s) origin + s * dir, FUN.VALUE = numeric(2))
  # `orth_dir` has length 2 and is recycled down every column.
  start <- pos + length/2 * orth_dir
  end <- pos - length/2 * orth_dir
  # Store the coordinates as matrix columns with one row per tick.
  dat <- tibble(pos = t(pos), start = t(start), end = t(end))
  geom_segment(data = dat, aes(x = start[,1], xend = end[,1], y = start[,2], yend = end[,2]), ...)
}

# Write text labels along an arbitrary straight line in a ggplot, rotated to
# run parallel to the line.
#
# origin:   point (x, y) on the line that corresponds to position 0.
# dir:      direction vector (dx, dy) of the line.
# labels:   text shown at each position; defaults to the values of `at`.
# at:       positions along the line (in multiples of `dir`).
# offset:   shift of the labels orthogonal to the line.
# extra_df: optional data frame of further columns bound to the label data.
# ...:      forwarded to `geom_text()`.
#
# Returns a `geom_text()` layer with its own data.
annotate_labels_along <- function(origin = c(0,0), dir = c(1,0), labels = at, at = 0, offset = 0, extra_df = NULL, ...){
  orth_dir <- c(dir[2], -dir[1])
  # 2 x length(at) matrix with one label position per column. Base `vapply()`
  # replaces the unexported helper `lemur:::mply_dbl()`, so the function no
  # longer depends on the internals of another package.
  pos <- vapply(at, \(s) origin + s * dir, FUN.VALUE = numeric(2))
  dat <- bind_cols(tibble(pos = t(pos), labels), extra_df)
  # Rotation of the text in degrees.
  angle <- atan2(dir[2], dir[1]) / pi * 180
  geom_text(data=dat, aes(label = labels, x = pos[,1] + offset * orth_dir[1], y = pos[,2] + offset * orth_dir[2]), angle = angle, ...)
}

# Percentiles that are marked along the reference line of the QQ plot.
label_pos <- c(0.001, 0.01, 0.1, 0.2, 0.5, 0.8, 0.9, 0.99, 0.999)

# Normal QQ plot of the deviations from additivity. The reference line
# through the origin has the slope of the fitted null standard deviation.
qq_pl <- filter_gt_df %>%
  mutate(percent_rank = percent_rank(error)) %>%
  # Sort so that the i-th smallest deviation is paired with the i-th
  # theoretical normal quantile.
  arrange(error) %>%
  mutate(expect_quantile = qnorm(ppoints(n()))) %>%
  ggplot(aes(x = expect_quantile, y = error)) +
    geom_abline(slope = sd_est) +
    # Percentile ticks and labels along the reference line; the lower
    # percentiles are labelled on one side, the upper ones on the other.
    annotate_ticks(dir = c(1, sd_est), at  = qnorm(label_pos), length = 0.17) +
    annotate_labels_along(dir = c(1, sd_est), at = qnorm(label_pos[1:4]), labels = label_pos[1:4], offset = -0.25, size = font_size_small / .pt) +
    annotate_labels_along(dir = c(1, sd_est), at = qnorm(label_pos[5:9]), labels = label_pos[5:9], offset = 0.25, size = font_size_small / .pt) +
    annotate_labels_along(dir = c(1, sd_est), at = 4.5, labels = "Percentile", offset = 0.15, size = font_size_small / .pt) +
    ggrastr::rasterize(geom_point(size = 0.3, stroke = 0), dpi = 300) +
    coord_fixed() +
    labs(x = "Quantiles of a standard normal distribution", y = "Quantiles of the observed \nexpression minus additive expectation")
qq_pl

# Parse interval labels such as "[-0.45, -0.44)" or "(0, 0.01]" into their
# numeric end points.
#
# label: character vector (or factor) of interval labels.
#
# Returns a numeric matrix with one row per label and two columns (lower and
# upper end point). Labels that are NA or do not look like an interval give a
# row of NA.
#
# Besides plain decimals, end points in scientific notation ("1e-04") and
# infinite end points ("-Inf", "Inf") are recognised; they used to be
# turned into NA. Only base R is used.
bin_numeric <- function(label){
  label <- as.character(label)
  number <- "[+-]?(?:\\d+\\.?\\d*(?:[eE][+-]?\\d+)?|Inf)"
  pattern <- paste0("^[(\\[](", number, "),\\s*(", number, ")[\\])]$")
  bounds <- matrix(NA_real_, nrow = length(label), ncol = 2L)
  # `grepl()` yields FALSE for NA input, so those rows simply stay NA.
  hit <- grepl(pattern, label, perl = TRUE)
  bounds[hit, 1L] <- as.numeric(sub(pattern, "\\1", label[hit], perl = TRUE))
  bounds[hit, 2L] <- as.numeric(sub(pattern, "\\2", label[hit], perl = TRUE))
  bounds
}

# Return the first row(s) of `data` that satisfy a condition.
#
# data:      data frame.
# condition: filter expression, evaluated inside `data`.
# order_by:  expression used to sort the matching rows; the default keeps
#            their current order.
# ...:       forwarded to `slice_head()` (e.g. `n`).
slice_first <- function(data, condition, order_by = row_number(), ...){
  data |>
    filter({{condition}}) |>
    arrange({{order_by}}) |>
    slice_head(...)
}

# For every deviation, estimate which share of the observed density at that
# value is explained by the fitted null distribution.
dens_ratio_df <- filter_gt_df %>%
  filter(seed == 1) %>%
  # obs_dens: kernel density estimate of all deviations, interpolated at each
  # observed value.
  mutate(obs_dens = error |> (\(err){
    dens <- density(err, bw = "nrd0")
    approx(dens$x, dens$y, err)$y
  })(),
  # expected_dens: null normal density scaled by the null proportion.
  expected_dens = p0_est * dnorm(error, mean = mean_est, sd = sd_est)) %>%
  # The ratio is capped at 1.
  mutate(ratio =  pmin(1, expected_dens / obs_dens)) 

# upper_thres <- dens_ratio_df %>% slice_first(ratio < 0.1 & error > 0, order_by = error) %>% pull(error)
# lower_thres <- dens_ratio_df %>% slice_first(ratio < 0.1 & error < 0, order_by = desc(error)) %>% pull(error)

# Number of deviations above the upper threshold, below the lower threshold
# and in between, formatted with thousands separators. `pos` is the x
# position at which each count is printed in the histogram.
count_labels <- filter_gt_df %>%
  mutate(label = case_when(
    error > upper_thres ~ "synergy",
    error < lower_thres ~  "suppressive",
    .default = "additive"
  )) %>%
  count(label) %>%
  mutate(n = scales::label_comma()(n)) %>%
  # No `by` is given, so the join uses the shared column `label`.
  left_join(enframe(c(additive = 0, suppressive = -0.4, synergy = 0.4), name = "label", value = "pos"))
## Joining with `by = join_by(label)`
# Histogram of the deviations from additivity in bins of width 0.01. Every
# bar is split into the part attributed to the null distribution ("h0",
# light grey) and the remainder ("h1", black).
error_histogram <- dens_ratio_df %>%
  mutate(error_bin = santoku::chop_width(error, width = 0.01)) %>%
  # Numeric lower and upper end point of each bin, parsed from its label.
  mutate(bin_num = bin_numeric(as.character(error_bin))) %>%
  summarize(count_h0 = n() * mean(ratio),
            count_h1 = n() * (1-mean(ratio)),
            .by = c(error_bin, bin_num)) %>%
  # Long format: one row per (bin, origin) with a single `count` column.
  pivot_longer(starts_with("count_"), names_sep = "_", names_to = c(".value", "origin"))  %>%
  mutate(origin = factor(origin, levels = c("h1", "h0"))) %>%
  mutate(bin_width = matrixStats::rowDiffs(bin_num)) %>%
  # x: bin midpoint; y: share of all counts divided by the bin width, i.e. a density.
  ggplot(aes(x = rowMeans(bin_num), y = count / sum(count) / bin_width)) +
    geom_col(aes(fill = origin), width = 0.01, position = "stack", show.legend = FALSE) +
    # Red curve: fitted null density scaled by the null proportion.
    geom_function(fun = \(x)  p0_est * dnorm(x, mean = mean_est, sd = sd_est), n = 1e4, color = "red") +
    geom_vline(xintercept = c(lower_thres, upper_thres), color = "#040404", linewidth = 0.2) +
    geom_text(data = count_labels, aes(x = pos, y = Inf, label = n), hjust = 0.5, vjust = 1.2, size = font_size_small / .pt) +
    scale_fill_manual(values = c("h0" = "lightgrey", "h1" = "black")) +
    # Bars outside these limits are dropped with a warning.
    scale_x_continuous(limits = c(-0.45, 0.45)) +
    scale_y_continuous(expand = expansion(mult = c(0, 0.1))) +
    labs(y = "density", x = "Observed LFC over additive expectation")

error_histogram
## Warning: `position_stack()` requires non-overlapping x intervals.
## Warning: Removed 158 rows containing missing values or values outside the scale
## range (`geom_col()`).

# Assemble the QQ plot (panel A) and the null decomposition histogram
# (panel B) into one PDF of 180 mm x 50 mm.
# `plot_assemble`, `add_text`, `add_plot` and `font_size` are not defined in
# this file; presumably they come from util.R -- TODO confirm.
plot_assemble(
  add_text("(A) Quantile-Quantile plot of the difference from the additive expectation", 
           x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(qq_pl, x = 3, y = 2, width = 120, height = 47.5),
  
  add_text("(B) Empirical null decomposition", x = 124.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(error_histogram, x = 125, y = 4.5, width = 50, height = 41.5),
  
  width = 180, height = 50, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl-qqplot.pdf"
)
## gg[gg1]
## gg[gg2]
## gg[gg3]
## Warning: `position_stack()` requires non-overlapping x intervals.
## Warning: Removed 158 rows containing missing values or values outside the scale
## range (`geom_col()`).
## gg[gg4]
## [1] TRUE
# Export the data behind the QQ plot figure.
writexl::write_xlsx(filter_gt_df, path = "../source_data/suppl-qqplot.xlsx")
# Tabulate how many deviations lie above the upper or below the lower threshold.
filter_gt_df %>%
  count(error > upper_thres, error < lower_thres)
## # A tibble: 3 × 3
##   `error > upper_thres` `error < lower_thres`      n
##   <lgl>                 <lgl>                  <int>
## 1 FALSE                 FALSE                 118965
## 2 FALSE                 TRUE                    1408
## 3 TRUE                  FALSE                   3627
# Fill colours for the interaction categories.
non_additive_colors <- c("Additive" = "lightgrey", "Other" = "#767676", "Non-additive" = "#00BA38", "Synergy" = "#fdc086", 
                         "Buffering" = "#beaed4", "Cryptic" = colorspace::darken("#f0027f", 0.5))



# Label every (gene, perturbation pair) among the 1000 most highly expressed
# genes by how the double perturbation AB deviates from the additive
# expectation A + B - ctrl. `upper_thres` (> 0) and `lower_thres` (< 0) are
# the cut-offs derived from the local fdr fit.
gene_response_label_df_intermed <- ground_truth_df |>
  inner_join(expr_rank_df %>% dplyr::select(gene_name, seed, expr_rank) %>% filter(expr_rank <= 1000), by = c("gene_name", "seed")) %>%
  mutate(add =  (A + B - ctrl)) |>
  # TRUE when A, B and AB all differ from control in the same direction.
  mutate(pert_same_dir = sign(A - ctrl) == sign(B - ctrl) & sign(AB - ctrl) == sign(A - ctrl)) %>%
  # The branches are tried in order; `ctrl < (A+B-ctrl)` is the up-regulated
  # case and `ctrl > (A+B-ctrl)` its mirror image.
  mutate(label = case_when(
    # Synergy: AB goes further away from control than the additive expectation.
    pert_same_dir & ctrl < (A+B-ctrl) & AB - (A + B - ctrl) > upper_thres ~ "Synergy",
    pert_same_dir & ctrl > (A+B-ctrl) & AB - (A + B - ctrl) < lower_thres ~ "Synergy",
    # Cryptic: AB falls short of the additive expectation and ends up on the
    # other side of control. The second branch mirrors the first, so it needs
    # `> upper_thres`; with `< lower_thres` it contradicted `AB > ctrl`.
    # NOTE(review): `pert_same_dir` already requires AB to lie on the same
    # side of control as A and B, so neither Cryptic branch can match as
    # written -- confirm whether the sign condition on AB is intended here.
    pert_same_dir & ctrl < (A+B-ctrl) & AB - (A + B - ctrl) < lower_thres & AB < ctrl ~ "Cryptic",
    pert_same_dir & ctrl > (A+B-ctrl) & AB - (A + B - ctrl) > upper_thres & AB > ctrl ~ "Cryptic",
    # Buffering: AB stays closer to control than the additive expectation.
    pert_same_dir & ctrl < (A+B-ctrl) & AB - (A + B - ctrl) < lower_thres ~ "Buffering",
    pert_same_dir & ctrl > (A+B-ctrl) & AB - (A + B - ctrl) > upper_thres ~ "Buffering",
    # Other: beyond a threshold, but not all three changes share a direction.
     AB - (A + B - ctrl) > upper_thres |  AB - (A + B - ctrl) < lower_thres ~ "Other",
    .default = "Additive"
  )) 


# Turn the interaction label into a factor with a fixed level order.
gene_response_label_df <- mutate(
  gene_response_label_df_intermed,
  label = factor(label, levels = c("Additive", "Other", "Buffering", "Synergy", "Cryptic"))
)

# Count and fraction of each interaction label for seed 1, plus a coarse
# additive / non-additive grouping.
non_additive_counts <- gene_response_label_df |>
  filter(seed == 1) |>
  dplyr::count(label) |>
  mutate(
    frac = n / sum(n),
    is_additive = ifelse(label == "Additive", "Additive", "Non-additive")
  )

print(non_additive_counts)
## # A tibble: 4 × 4
##   label          n    frac is_additive 
##   <fct>      <int>   <dbl> <chr>       
## 1 Additive  118965 0.959   Additive    
## 2 Other       1396 0.0113  Non-additive
## 3 Buffering   2878 0.0232  Non-additive
## 4 Synergy      761 0.00614 Non-additive
# Fraction of the seed-1 observations labelled "Additive"; reused below as the
# end point of the "Additive" bracket on the secondary y axis.
perc_additive <- filter(non_additive_counts, label == "Additive")$frac

# Single stacked bar (drawn with geom_rect) showing the additive vs.
# non-additive fractions.
non_add_pl1 <- non_additive_counts %>%
  summarize(frac = sum(frac), .by = is_additive) |>
  # Stack the segments: each one starts where the previous one ends.
  mutate(start = cumsum(lag(frac, default = 0)),
         end = cumsum(frac)) |>
  ggplot(aes(ymin = start, ymax = end, xmin = 0.4, xmax = 0.6)) +
    geom_rect(aes(fill = is_additive), show.legend = FALSE) +
    # The "%" in the tick labels is escaped as "\%" (presumably because the
    # figure is assembled with latex_support = TRUE).
    scale_y_continuous(limits = c(0, 1), labels = \(x) paste0(x * 100, "\\%"), expand = expansion(add = 0), position = "left") +
    scale_fill_manual(values = non_additive_colors) +
    # Secondary axis carries a text annotation spanning 0 .. perc_additive; the
    # bracket line itself is hidden through the theme settings below.
    guides(y.sec = legendry::compose_stack(legendry::primitive_bracket(key = legendry::key_range_manual(start = 0, end = perc_additive, name = "Additive"), angle = -90))) +
    theme(legendry.bracket = element_blank(),
          legendry.bracket.size = unit(0, "pt"),
          axis.text.x = element_blank(),
          axis.ticks.x = element_blank(),
          axis.title.x = element_blank(),
          axis.line.x = element_blank()) 

# Cumulative start/end positions of every label (in factor-level order), then
# drop the "Additive" segment so only the non-additive classes remain.
non_additive_counts_helper_df <- non_additive_counts %>%
  arrange(label) %>%
  mutate(start = cumsum(lag(frac, default = 0)),
         end = cumsum(frac)) |>
  mutate(label = fct_rev(label)) |>
  filter(label != "Additive")

# Range key naming each non-additive segment on the secondary axis.
non_additive_annot_key <- legendry::key_range_manual(start = non_additive_counts_helper_df$start, end = non_additive_counts_helper_df$end, name = non_additive_counts_helper_df$label)
# NOTE(review): presumably forces all ranges onto a single bracket level -
# confirm against the legendry documentation.
non_additive_annot_key$.level <- 1

# Zoomed-in bar of the non-additive classes only; the lower y limit is left to
# the data (NA), the upper limit is fixed at 1.
non_add_pl2 <- non_additive_counts_helper_df |>
  ggplot(aes(ymin = start, ymax = end, xmin = 0.4, xmax = 0.6)) +
    geom_rect(aes(fill = label), show.legend = FALSE) +
    scale_y_continuous(limits = c(NA, 1), labels = \(x) paste0(x * 100, "\\%"), expand = expansion(add = 0), position = "left") +
    scale_fill_manual(values = non_additive_colors) +
    guides(y.sec = legendry::compose_stack(legendry::primitive_bracket(key = non_additive_annot_key, angle = -90))) +
    theme(zoom = element_rect(fill = "grey"),
          legendry.bracket = element_blank(),
          legendry.bracket.size = unit(0, "pt"),
          axis.text.x = element_blank(),
          axis.ticks.x = element_blank(),
          axis.title.x = element_blank(),
          axis.line.x = element_blank())
# Print both bars for visual inspection.
non_add_pl1

non_add_pl2

# Long table of per-gene predictions for held-out double perturbations, with
# the observed and the predicted deviation from the additive model.
inter_pred_dat <- res %>%
  # Keep held-out rows only.
  filter(train %in% c("test", "val")) %>%
  # Keep perturbations with exactly two non-control components.
  filter(lengths(map(perturbation_split, \(x) setdiff(x, "ctrl"))) == 2) %>%
  filter(method %in% c("ground_truth", names(method_labels))) %>%
  tidylog::left_join(baselines, by = c("dataset_name", "seed")) %>%
  dplyr::select(perturbation, method, seed, prediction, baseline) %>%
  # One row per gene: gene names are taken from the names of each prediction
  # vector and unnested in parallel with the prediction and baseline values.
  mutate(gene_name = map(prediction, names)) %>%
  unnest(c(gene_name, prediction, baseline)) %>%
  inner_join(expr_rank_df %>% dplyr::select(gene_name, seed, expr_rank) %>% filter(expr_rank <= 1000), by = c("gene_name", "seed")) %>%
  # One column per method, so the additive model can serve as the reference.
  pivot_wider(id_cols = c(perturbation, gene_name, baseline, seed), names_from = method, values_from = prediction) %>%
  mutate(ref = additive_model) %>%
  # Back to long format; `ground_truth` and `ref` stay as separate columns.
  pivot_longer(c(scgpt, gears, scfoundation, scbert, geneformer, uce, cpa, additive_model, no_change), names_to = "method") %>%
  mutate(obs_minus_add = ground_truth - ref,
         pred_minus_add = value - ref) %>%
  mutate(method = factor(method, levels = names(method_labels)))
## left_join: added one column (baseline)
##            > rows only in x              0
##            > rows only in baselines (    0)
##            > matched rows            3,100
##            >                        =======
##            > rows total              3,100
# Attach the approach annotation and the observed interaction label to every
# prediction; the inner join drops rows without a label for the same
# perturbation, gene and seed.
pert_pred_comparison_df <- inter_pred_dat %>%
  filter(method %in% names(method_labels)) %>%
  mutate(method = factor(method, levels = names(method_labels))) %>%
  left_join(approach_annot, by = "method") %>%
  tidylog::inner_join(gene_response_label_df|> dplyr::select(gene_name, pert_group, seed, interaction_label = label), by = c("perturbation" = "pert_group", "gene_name", "seed"))
## inner_join: added one column (interaction_label)
##             > rows only in x                         (        0)
##             > rows only in dplyr::select(gene_resp.. (  310,000)
##             > matched rows                            2,790,000
##             >                                        ===========
##             > rows total                              2,790,000
# For observations labelled "Synergy": count per method how many fall into each
# corner of the observed-vs-predicted deviation plane, i.e. both deviations lie
# beyond a threshold. "left"/"right" refer to the observed deviation,
# "bottom"/"top" to the predicted one.
pert_pred_comparison_df %>%
  filter(interaction_label == "Synergy") |>
  count(bottom_left = obs_minus_add < lower_thres & pred_minus_add < lower_thres,
         top_left = obs_minus_add < lower_thres & pred_minus_add > upper_thres,
         top_right = obs_minus_add > upper_thres & pred_minus_add > upper_thres,
         bottom_right = obs_minus_add > upper_thres & pred_minus_add < lower_thres,
        method) |>
  # Collapse the four logical indicator columns into one `corner` column and
  # keep only the rows where the indicator is TRUE.
  pivot_longer(-c(method, n), names_to = "corner", values_to = "exist") |>
  filter(exist) |>
  dplyr::select(-exist) |>
  print(n = 50)
## # A tibble: 26 × 3
##    method           n corner      
##    <fct>        <int> <chr>       
##  1 no_change      321 bottom_right
##  2 scgpt          255 bottom_right
##  3 scfoundation   202 bottom_right
##  4 uce            250 bottom_right
##  5 scbert         250 bottom_right
##  6 geneformer      69 bottom_right
##  7 gears          147 bottom_right
##  8 cpa             75 bottom_right
##  9 scfoundation    15 top_right   
## 10 geneformer      78 top_right   
## 11 gears           36 top_right   
## 12 cpa            133 top_right   
## 13 no_change      714 top_left    
## 14 scgpt          405 top_left    
## 15 scfoundation    90 top_left    
## 16 uce            251 top_left    
## 17 scbert         251 top_left    
## 18 geneformer     161 top_left    
## 19 gears          158 top_left    
## 20 cpa            821 top_left    
## 21 scfoundation   185 bottom_left 
## 22 uce              1 bottom_left 
## 23 scbert           1 bottom_left 
## 24 geneformer     204 bottom_left 
## 25 gears          153 bottom_left 
## 26 cpa            116 bottom_left
# Scatter plot of observed vs. predicted deviation from the additive
# expectation, one facet per method (nested under its approach).
pert_pred_comparison <- pert_pred_comparison_df %>%
  # Flag, per method, the 500 predictions with the largest absolute deviation;
  # they are drawn larger and fully opaque.
  mutate(most_non_additive = rank(desc(abs(pred_minus_add))) <= 500, .by = c(method)) |>
  # Clip predictions to [-1.37, 1.37] (clamp() defaults `min` to `-max`).
  mutate(pred_minus_add = clamp(pred_minus_add, max = 1.37)) |>
  # Sorting by label determines the order in which the points are drawn.
  arrange(interaction_label) |>
  ggplot(aes(x = obs_minus_add, y = pred_minus_add)) +
    # geom_point(data = tibble(interaction_label = "Below baseline"), aes(x= 0, y = 0, color = interaction_label), stroke = 0, size = 0) +
    ggrastr::rasterize(geom_point(aes(color = interaction_label, alpha = most_non_additive, size = most_non_additive), stroke = 0), dpi = 600) +
    # Dashed diagonal: predicted deviation equals observed deviation.
    geom_abline(slope = 1, intercept = 0, linetype = "dashed", alpha = 0.3) +
    # Vertical lines mark the lower and upper non-additivity thresholds.
    geom_vline(xintercept = c(lower_thres, upper_thres), linewidth = 0.2) +
    scale_x_continuous(expand = expansion(add = 0), breaks = c(-0.8, 0, 0.8)) +
    scale_y_continuous(expand = expansion(add = 0), breaks = c(-1, 0, 1)) +
    scale_color_manual(values = non_additive_colors, 
                       labels = c("Additive", "Other", "Buffering", "Synergy", "Cryptic"), drop = TRUE) +
    scale_alpha_manual(values = c("TRUE" = 1, "FALSE" = 0.1)) +
    scale_size_manual(values = c("TRUE" = 0.6, "FALSE" = 0.1)) +
    ggh4x::facet_nested_wrap(vars(approach, method), nrow = 1,
                             labeller = labeller(approach = as_labeller(approach_labels), 
                                                 method = as_labeller(method_labels)),
                             nest_line = TRUE, strip = ggh4x::strip_nested(clip = "off")) +
    coord_fixed(xlim = c(-1, 1), ylim = c(-1.4, 1.4)) +
    labs(x = "Observed expression minus additive expectation", 
       y = "Predicted expr.\\ minus\nadditive expectation", 
       color = "") +
    guides(color = guide_legend(override.aes = list(size = 2)), alpha = "none", size = "none") +
    theme(panel.spacing.x = unit(2, "mm"), legend.position = "bottom")
pert_pred_comparison

#' Interpolate `y` over `x` and return the result as a tibble
#'
#' Thin wrapper around `approx()` for use inside dplyr verbs. The two output
#' columns are named after the expressions supplied for `x` and `y` (e.g.
#' `approx2(fdp, tpr, ...)` yields columns `fdp` and `tpr`) instead of the
#' generic `x` / `y` that `approx()` returns.
#'
#' @param x,y Vectors of equal length giving the points to interpolate.
#' @param ... Further arguments forwarded to `approx()` (e.g. `xout`,
#'   `yleft`, `yright`).
#' @return A two-column tibble with the interpolated points.
approx2 <- function(x, y, ...) {
  # `{{ }}` lets tibble() derive the column names from the caller's expressions.
  input <- tibble({{ x }}, {{ y }})
  interpolated <- approx(input[[1]], input[[2]], ...)
  out <- as_tibble(interpolated)
  names(out) <- names(input)
  out
}

# Ranking-based detection statistics per (method, seed): predictions are ranked
# by the absolute predicted deviation from the additive model, and an
# observation counts as truly non-additive when the observed deviation lies
# outside [lower_thres, upper_thres].
tp_fdp_prec_recall_data_pre <- inter_pred_dat %>%
  tidylog::left_join(gene_response_label_df|> dplyr::select(gene_name, pert_group, seed,pert_same_dir,interaction_label = label), by = c("perturbation" = "pert_group", "gene_name", "seed")) |>
  filter(method %in% names(method_labels)) %>%
  # The additive model is the reference itself, so it is excluded here.
  filter(method != "additive_model") |>
  mutate(method = factor(method, levels = names(method_labels))) %>%
  left_join(approach_annot, by = "method") %>%
  mutate(true_nonadditive = obs_minus_add > upper_thres | obs_minus_add < lower_thres) %>%
  group_by(method, seed) %>%
  # arrange() orders the whole table (it ignores the grouping by default); the
  # grouped mutate() calls below then accumulate within each (method, seed)
  # group following that row order.
  arrange(desc(abs(pred_minus_add))) %>%
  mutate(tp = cumsum(true_nonadditive),
         fp = cumsum(! true_nonadditive)) %>%
  mutate(fdp = fp / pmax(1, fp + tp),
         fpr = fp / sum(! true_nonadditive),
         precision = tp / (tp + fp),
         recall = tp / sum(true_nonadditive)) %>%
  # Re-sort by FDP and take the running maximum of TP, so `tpr` is
  # non-decreasing in `fdp`. `precision` and `recall` were computed above and
  # still refer to the un-monotonised `tp`.
  arrange(fdp) %>%
  mutate(tp = cummax(tp)) %>%
  mutate(tpr = tp / sum(true_nonadditive)) |>
  ungroup()
## left_join: added 2 columns (pert_same_dir, interaction_label)
##            > rows only in x                                  0
##            > rows only in dplyr::select(gene_resp.. (  310,000)
##            > matched rows                            2,790,000
##            >                                        ===========
##            > rows total                              2,790,000
# TPR-vs-FDP curve per method: interpolate each (method, seed) curve onto a
# common grid of 101 FDP values in [0, 1], then average the TPR over seeds.
tp_fdp_data <- tp_fdp_prec_recall_data_pre %>%
  group_by(method, seed) %>%
  # Many rows share the same FDP value. `ties = mean` is approx()'s default
  # behaviour; stating it explicitly keeps the result unchanged while avoiding
  # one "collapsing to unique 'x' values" warning per group.
  # Outside the observed FDP range the curve is extended with 0 (left) and
  # 1 (right).
  reframe(tmp = approx2(fdp, tpr, xout = seq(0, 1, length.out = 101),
                        yleft = 0, yright = 1, ties = mean)) %>%
  # reframe() already returns an ungrouped data frame, so no ungroup() needed.
  unnest(tmp) %>%
  summarize(tpr = mean(tpr), .by = c(method, fdp)) %>%
  mutate(method = factor(method, levels = names(method_labels)))
## Warning: There were 40 warnings in `reframe()`.
## The first warning was:
## ℹ In argument: `tmp = approx2(...)`.
## ℹ In group 1: `method = no_change` and `seed = 1`.
## Caused by warning in `regularize.values()`:
## ! collapsing to unique 'x' values
## ℹ Run `dplyr::last_dplyr_warnings()` to see the 39 remaining warnings.
# Colour lookup keyed by the combined "<method label>|<approach label>" string
# that is used as the line colour aesthetic in the plot below.
# NOTE(review): `ggplot_colors_five[method]` indexes by the integer codes if
# `method` is a factor (the as.integer(method) call below suggests it is),
# whereas the next line indexes by name via as.character() - confirm that the
# positional lookup is intended.
colors_adapted <- approach_annot |> 
  mutate(color = ggplot_colors_five[method]) |>
  mutate(label = paste0(method_labels[as.character(method)], "|", approach_labels[as.character(approach)])) |>
  mutate(label = fct_reorder(label, as.integer(approach) * 1000 + as.integer(method) ))

# FDP value (x position) next to which each method's text label is placed.
label_pos <- c(no_change = 0.4, scgpt = 0.25, uce = 0.25, scfoundation = 0.5,
               gears = 0.4, geneformer = 0.3, cpa = 0.7, scbert = 0.32)

# True positive rate as a function of the false discovery proportion, one line
# per method, with a log-scaled y axis.
tp_fdp_pl <- tp_fdp_data %>%
  left_join(approach_annot, by = "method") %>%
  mutate(approach = fct_relevel(approach, "baseline", "deep_learning", "foundation_model")) %>%
  mutate(label = paste0(method_labels[as.character(method)], "|", approach_labels[as.character(approach)])) |>
  mutate(label = fct_reorder(label, as.integer(approach) * 1000 + as.integer(method) )) %>%
  ggplot(aes(x = fdp, y = tpr)) +
    geom_line(aes(color = label)) +
    # Per method, only the grid point closest to label_pos[method] gets a
    # non-empty text label; all other points carry "".
    ggrepel::geom_text_repel(data = . %>% mutate(annot = ifelse(rank((fdp - label_pos[as.character(method)])^2, ties.method = "first") == 1, method_labels[as.character(method)],  ""), .by = method),
                              aes(label = annot), size = font_size_tiny / .pt,
                             min.segment.length = 0, box.padding = 0.5, point.padding = 0,
                             color = "black", bg.colour  = "white", max.overlaps = Inf, seed = 1,
                             ylim = c(-Inf, Inf),
                             # label.size = NA, label.padding = unit(0.1, "mm")
                             ) +
    annotation_logticks(sides = "l", short = unit(0.3, "mm"), mid = unit(2/3, "mm"), long = unit(1, "mm")) +
    scale_color_manual(values = deframe(colors_adapted[c("label", "color")])) +
    scale_y_log10(breaks = c(0.01, 0.1, 1)) +
    scale_x_continuous(expand = expansion(add = 0)) +
    labs(x = "False discovery proportion ($\\frac{\\textrm{FP}}{\\textrm{FP}+\\textrm{TP}}$)",
         y = "True Positive Rate ($\\frac{\\textrm{TP}}{\\textrm{TP}+\\textrm{FN}}$)",
         color = "") +
    coord_cartesian(ylim = c(0.001, 1), xlim = c(0, 1)) + 
    theme(legend.position = "none", legend.key.height=unit(0.1,"mm"))

tp_fdp_pl

# For the 500 predictions with the largest absolute predicted deviation per
# method: classify the *predicted* interaction and cross-tabulate it against
# the *observed* interaction label.
prediction_label_vs_true_df <- tp_fdp_prec_recall_data_pre |>
  slice_max(n = 500, order_by = abs(pred_minus_add), by = c(method), with_ties = FALSE) |>
  # Position of the prediction (`value`) relative to the control baseline and
  # the additive reference (`ref`):
  #   buf  - between baseline and ref
  #   syn  - beyond ref, away from the baseline
  #   anti - on the other side of the baseline than ref
  mutate(prediction_label = case_when(
    baseline < ref & ref > value & value >= baseline ~ "buf",
    baseline > ref & ref < value & value <= baseline ~ "buf",
    baseline < ref & ref < value                     ~ "syn",
    baseline > ref & ref > value                     ~ "syn",
    baseline < ref & baseline > value                ~ "anti",
    baseline > ref & baseline < value                ~ "anti",
    .default = "other"
  )) |>
  count(method, interaction_label, prediction_label) |>
  # Add explicit zero counts for combinations that never occur.
  complete(method, prediction_label, interaction_label, fill = list(n = 0)) |>
  summarize(n = median(n), .by = c(method, prediction_label, interaction_label)) |>
  # Fraction of each observed class within a (method, predicted class) pair.
  mutate(frac = n / sum(n), .by = c(prediction_label, method)) 


# Spot check: observed classes among Geneformer's predicted buffering cases.
prediction_label_vs_true_df |> filter(method =="geneformer" & prediction_label == "buf")
## # A tibble: 5 × 5
##   method     prediction_label interaction_label     n   frac
##   <fct>      <chr>            <fct>             <int>  <dbl>
## 1 geneformer buf              Additive             54 0.286 
## 2 geneformer buf              Other                16 0.0847
## 3 geneformer buf              Buffering           111 0.587 
## 4 geneformer buf              Synergy               8 0.0423
## 5 geneformer buf              Cryptic               0 0
# Mosaic-style bar chart: for each method (facet) and predicted class (row),
# the bar shows the composition of observed interaction classes.
mosaic_plot <- prediction_label_vs_true_df |>
  # Number of top predictions in each (method, predicted class) pair, and that
  # number as a share within (method, observed class); the share sets the bar
  # width below.
  mutate(marginal_n = sum(n), .by = c(method, prediction_label)) |>
  mutate(marginal_frac = marginal_n / sum(marginal_n), .by = c(method, interaction_label)) |>
  mutate(interaction_label = fct_relevel(interaction_label, "Other", "Additive", "Buffering", "Synergy")) |>
  ggplot(aes(x = frac, y = prediction_label)) + 
    # Bar width scales with the marginal share, with a floor of 0.1; rows with
    # a zero share get width 0. ggplot2 warns that `width` is an unknown
    # aesthetic here (see the output below).
    geom_col(aes(fill = interaction_label, width = ifelse(marginal_frac == 0, 0, pmax(0.1, 1.5 * marginal_frac)))) +
    shadowtext::geom_shadowtext(data = . %>% distinct(method, prediction_label, marginal_n),
              aes(label = paste0("n=", scales::label_comma()(round(marginal_n))),
                  # x = ifelse(marginal_n == 0, 0, Inf), hjust = ifelse(marginal_n == 0, 0, 0)),
                  x = 0.02, hjust = 0), color = "black", bg.colour = "white",
              vjust = 0.5, size = font_size_tiny / .pt, nudge_x = 0.01) +
    scale_x_continuous(labels = \(x) paste0(x * 100, "\\%"), breaks = c(0, 0.5, 1), expand = expansion(add = 0), position = "bottom") +
    scale_y_discrete(labels = c(buf = "Buffering", syn = "Synergy", anti = "Opposite")) +
    scale_fill_manual(values = non_additive_colors) +
    facet_wrap(vars(method), labeller = labeller(method = as_labeller(method_labels)), nrow = 1, scales = "fixed") +
    labs(y = "Predicted class", x = "Proportion of observed interaction classes") +
    guides(fill = "none") +
    coord_cartesian(ylim = c(0.4, 3.2), xlim = c(0,1), clip = "off", expand = FALSE) +
    theme(panel.spacing.x = unit(5, "mm"), panel.spacing.y = unit(1.5, "mm"),
          strip.text = element_text(margin = margin(2,2,b=0, 2, "pt"))
          )
## Warning in geom_col(aes(fill = interaction_label, width = ifelse(marginal_frac
## == : Ignoring unknown aesthetics: width
mosaic_plot
## Warning: Removed 40 rows containing missing values or values outside the scale
## range (`geom_col()`).

Make the receiver operating characteristic (ROC) and precision-recall (PRC) curves

# Legend labels "<method label> ($mean \pm se$)" for the precision-recall
# plot, as a character vector named by method and sorted by decreasing mean
# area. The area is computed per (method, seed) with the trapezoid rule:
# midpoints of consecutive precision values times the recall step. Rows are
# sorted by precision beforehand and the sum is negated.
# NOTE(review): the negation assumes recall decreases as precision increases
# along the sorted rows - confirm this ordering is the intended curve.
auprc_dat_labels <- tp_fdp_prec_recall_data_pre %>%
  arrange(precision) %>%
  summarize(auprc = -sum(zoo::rollmean(precision, k = 2) * diff(recall)),
            .by = c(method, seed)) %>%
  summarize(mean = mean(auprc),
            se = sd(auprc) / sqrt(n()),
            .by = method) %>%
  arrange(-mean) %>%
  # Look the label up by name: indexing with the factor itself would use its
  # integer codes, which is only correct while the factor levels match
  # names(method_labels) exactly and in the same order.
  transmute(method, label = paste0(method_labels[as.character(method)], " ($", round(mean, digits = 2), "\\pm", round(se, digits = 2), "$)")) %>%
  deframe()

# Same construction for the ROC plot: trapezoid-rule area under recall (TPR)
# vs. FPR per (method, seed), summarised as mean and standard error per method.
auc_dat_labels <- tp_fdp_prec_recall_data_pre %>%
  arrange(fpr) %>%
  summarize(auc = sum(zoo::rollmean(recall, k = 2) * diff(fpr)),
            .by = c(method, seed)) %>%
  summarize(mean = mean(auc),
            se = sd(auc) / sqrt(n()),
            .by = method) %>%
  arrange(-mean) %>%
  # Name-based lookup, see above.
  transmute(method, label = paste0(method_labels[as.character(method)], " ($", round(mean, digits = 2), "\\pm", round(se, digits = 2), "$)")) %>%
  deframe()

# Precision-recall curves, one facet per seed and one line per method. The
# method factor is re-levelled to the order of `auprc_dat_labels`, whose values
# are also used as legend labels.
prc_plot <- tp_fdp_prec_recall_data_pre %>%
  mutate(method = factor(method, levels = names(auprc_dat_labels))) %>%
  ggplot(aes(x = recall, y = precision)) +
    # Rasterise the lines (300 dpi) instead of drawing them as vector paths.
    ggrastr::rasterize(geom_line(aes(color = method), linewidth = 0.2), dpi = 300) +
    scale_color_manual(values = ggplot_colors_five, labels = auprc_dat_labels) +
    scale_x_continuous(breaks = c(0, 0.25, 0.5, 0.75, 1), labels = as.character(c(0, 0.25, 0.5, 0.75, 1))) +
    facet_wrap(vars(seed), labeller = label_both, nrow = 1) +
    coord_fixed() +
    guides(color = guide_legend(override.aes = list(linewidth = 0.8))) +
    labs(y = "Precision ($\\frac{\\textrm{TP}}{\\textrm{TP} + \\textrm{FP}}$)", 
         x = "Recall ($\\textrm{TPR} = \\frac{\\textrm{TP}}{\\textrm{TP} + \\textrm{FN}}$)",
         color = "",
         title = "(A) Precision-Recall Curve (PRC)")

# ROC curves (recall vs. false positive rate) with the same layout; the dashed
# grey diagonal has slope 1.
# NOTE(review): the title reads "Receiver Operator Curve"; the usual term is
# "receiver operating characteristic" - consider rewording.
roc_plot <- tp_fdp_prec_recall_data_pre %>%
  mutate(method = factor(method, levels = names(auc_dat_labels))) %>%
  ggplot(aes(x = fpr, y = recall)) +
    ggrastr::rasterize(geom_line(aes(color = method), linewidth = 0.2), dpi = 300) +
    geom_abline(slope = 1, color = "lightgrey", linewidth = 0.8, linetype = "dashed") +
    scale_color_manual(values = ggplot_colors_five, labels = auc_dat_labels) +
    scale_x_continuous(breaks = c(0, 0.25, 0.5, 0.75, 1), labels = as.character(c(0, 0.25, 0.5, 0.75, 1))) +
    facet_wrap(vars(seed), labeller = label_both, nrow = 1) +
    coord_fixed() +
    guides(color = guide_legend(override.aes = list(linewidth = 0.8))) +
    labs(x = "False Positive Rate ($\\frac{\\textrm{FP}}{\\textrm{FP} + \\textrm{TN}}$)", 
         y = "Recall ($\\textrm{TPR} = \\frac{\\textrm{TP}}{\\textrm{TP} + \\textrm{FN}}$)",
         color = "",
         title = "(B) Receiver Operator Curve (ROC)")

prc_plot

roc_plot

# Stack the PRC plot above the ROC plot on a 180 x 105 mm canvas and write the
# result to ../plots/suppl-roc_curves.pdf.
plot_assemble(
  add_plot(prc_plot, x = 0, y = 0, width = 180, height = 50),
  add_plot(roc_plot, x = 0, y = 52, width = 180, height = 50),

  width = 180, height = 105, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl-roc_curves.pdf"
)
## gg[gg1]
## gg[gg2]
##  [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
# Data is too large and cannot be saved
# writexl::write_xlsx(list(
#   tp_fdp_prec_recall_data_pre
# ), path = "../source_data/suppl-roc_curves.xlsx")
# Three colours of the "FrenchDispatch" palette, taken in reverse order, for
# the control level and the two single-perturbation effects in the cartoon.
cartoon_color_values <- as.character(wesanderson::wes_palette("FrenchDispatch", n = 3)[c(3,2,1)]) |>
  magrittr::set_names(c("ctrl", "Ad", "Bd"))

# Cartoon expression levels: control height and the (here positive) effects of
# perturbation A (`Ad`) and perturbation B (`Bd`).
height_spec <- c(ctrl = 0.4, Ad = 0.1, Bd = 0.15)

  
# One row per bar segment: the control bar has a single segment; bars A and B
# consist of the control segment plus their own effect.
training_df <- tibble(bar = c("ctrl", "A", "A", "B", "B"),
       part = c("ctrl", "ctrl", "Ad", "ctrl", "Bd")) |>
   mutate(height = height_spec[part])

# "Training data" cartoon: stacked bars for no perturbation, A and B, with an
# arrow on bars A and B spanning the perturbation effect above the control
# level.
train_plot <- training_df |>
  mutate(bar = factor(bar, c("ctrl", "A", "B"))) |>
  mutate(part = factor(part, c( "Bd", "Ad", "ctrl"))) |>
  arrange(bar, desc(part)) |>
  mutate(start_pos = cumsum(lag(height, default = 0)),
         end_pos = cumsum(height),
         .by = bar) |>
  ggplot(aes(x = bar, y = height)) +
    geom_col(aes(fill = part), width = 0.5) +
    annotate("segment", x = "A", y = height_spec["ctrl"], yend = height_spec["ctrl"] + height_spec["Ad"], color = colorspace::darken(cartoon_color_values["Ad"], 0.3),
             linewidth = 0.7, arrow = grid::arrow(type = "closed", length = unit(1, "mm"))) +
    annotate("segment", x = "B", y = height_spec["ctrl"], yend = height_spec["ctrl"] + height_spec["Bd"], color = colorspace::darken(cartoon_color_values["Bd"], 0.3),
             linewidth = 0.7, arrow = grid::arrow(type = "closed", length = unit(1, "mm"))) +
    scale_y_continuous(expand = expansion(add = 0)) +
    scale_x_discrete(labels = c(ctrl = "No Perturbation", A = "Perturbation A", B = "Perturbation B")) +
    scale_fill_manual(values = cartoon_color_values) +
    scale_color_manual(values = cartoon_color_values) +
    coord_cartesian(ylim = c(0, 0.8), clip = "off") +
    guides(color = "none", fill = "none", alpha = "none",
           x = guide_axis(angle = 90)) +
    # The y label set here is not displayed: the theme below blanks both axis
    # titles.
    labs(y = "Log Expression of a single gene",
         subtitle = "Training data") +
    theme(axis.title = element_blank(),
          axis.ticks.y.left = element_blank(),
          axis.text.y.left = element_blank())

# Cartoon bar heights of the double perturbation for each interaction class:
# the additive expectation (control plus both effects), a synergistic value
# above it, a buffered value below it, and an "opposite" value below control.
categories_df <- tibble(
  add = sum(height_spec),
  syn = add + 0.1,
  buf = add - 0.12,
  opp = height_spec["ctrl"] - 0.1
) |>
  pivot_longer(cols = everything(), names_to = "class", values_to = "height") |>
  mutate(class = factor(class, levels = c("add", "buf", "syn", "opp")))

# Stacked arrow segments (control, then A's effect, then B's effect), repeated
# for every interaction class so each bar gets the same set of arrows. Rows are
# sorted with `Bd` first, then `Ad`, then `ctrl`.
arrow_df <- enframe(height_spec, name = "part", value = "height") |>
  mutate(
    start_pos = cumsum(lag(height, default = 0)),
    end_pos = cumsum(height),
    class = list(categories_df$class),
    part = factor(part, levels = c("ctrl", "Ad", "Bd"))
  ) |>
  unnest(class) |>
  arrange(desc(part))


# Additive expectation (control + both effects) and the control level, stored
# under their own names because `height_spec` is redefined further down.
sum_height_spec1 <-  sum(height_spec)
rel_zero1 <- height_spec["ctrl"]
# Cartoon of the four interaction classes for the up-regulation example: grey
# bars give the double-perturbation level, coloured arrows the stacked single
# effects, and the two error bars next to the additive bar span +/- 0.05 around
# the additive expectation.
obs_plot <- categories_df |>
  ggplot(aes(x = class, y = height)) +
    geom_col(fill = "#5e5e5e", width = 0.5) +
    annotate(geom = "errorbar", x = 1.25, ymin = sum_height_spec1 - 0.05, ymax = sum_height_spec1 + 0.05, width = 0.1) +
    annotate(geom = "errorbar", x = 0.75, ymin = sum_height_spec1 - 0.05, ymax = sum_height_spec1 + 0.05, width = 0.1) +
    geom_segment(data = arrow_df, aes(color = part, y = start_pos, yend = end_pos), linewidth = 0.7,
                 arrow = grid::arrow(type = "closed", length = unit(1, "mm"))) +
    # The secondary axis shifts the scale so that the control level reads 0.
    scale_y_continuous(expand = expansion(add = 0), 
                       sec.axis = sec_axis(transform = \(x) x - rel_zero1, 
                                           name = "LFC over additive expectation", breaks = c(0))) +
    scale_x_discrete(labels = c(add = "Additive\n(Non-interaction)", buf = "Buffering", syn = "Synergy", opp = "Opposite")) +
    scale_fill_manual(values = cartoon_color_values) +
    scale_color_manual(values = cartoon_color_values) +
    coord_cartesian(ylim = c(0, 0.8), clip = "off") +
    guides(color = "none", fill = "none", alpha = "none",
           x = guide_axis(angle = 90)) +
    labs(subtitle = "Observed / Predicted Interaction\nClass of Perturbation A+B") +
    theme(axis.title = element_blank(),
          axis.ticks.y.left = element_blank(),
          axis.text.y.left = element_blank(),
          axis.line.y.left = element_blank())

# Preview the two upper cartoon panels side by side.
cowplot::plot_grid(train_plot, obs_plot, align = "h")

# Second cartoon example: both perturbations now lower the expression
# (negative effects). This overwrites `height_spec` and `training_df` from the
# first example.
height_spec <- c(ctrl = 0.4, Ad = -0.1, Bd = -0.15)

  
training_df <- tibble(bar = c("ctrl", "A", "A", "B", "B"),
       part = c("ctrl", "ctrl", "Ad", "ctrl", "Bd")) |>
  mutate(height = height_spec[part]) |>
  mutate(bar = factor(bar, c("ctrl", "A", "B"))) |>
  mutate(part = factor(part, c( "Bd", "Ad", "ctrl"))) |>
  arrange(bar, desc(part)) |>
  mutate(start_pos = cumsum(lag(height, default = 0)),
         end_pos = cumsum(height),
         .by = bar)

# "Training data" cartoon for the down-regulation example, drawn with tiles
# instead of stacked columns.
train_plot2 <- training_df |>
 # Every segment of a bar ends at that bar's lowest end position, so the
 # opaque control tile stops at the perturbed level and the translucent effect
 # tile covers the range between the perturbed level and the control level.
 mutate(end_pos = min(end_pos), .by = bar) |>
  ggplot(aes(x = bar)) +
    geom_tile(aes(fill = part, alpha = part, y = (start_pos + end_pos) / 2, height = abs(end_pos - start_pos)), width = 0.5) +
    annotate("segment", x = "A", y = height_spec["ctrl"], yend = height_spec["ctrl"] + height_spec["Ad"], color = cartoon_color_values["Ad"],
             linewidth = 0.7, arrow = grid::arrow(type = "closed", length = unit(1, "mm"))) +
    annotate("segment", x = "B", y = height_spec["ctrl"], yend = height_spec["ctrl"] + height_spec["Bd"], color = cartoon_color_values["Bd"],
             linewidth = 0.7, arrow = grid::arrow(type = "closed", length = unit(1, "mm"))) +
    scale_y_continuous(expand = expansion(add = 0)) +
    scale_x_discrete(labels = c(ctrl = "No Perturbation", A = "Perturbation A", B = "Perturbation B")) +
    scale_fill_manual(values = cartoon_color_values) +
    scale_color_manual(values = cartoon_color_values) +
    scale_alpha_manual(values = c(ctrl = 1, Ad = 0.2, Bd = 0.2)) +
    coord_cartesian(ylim = c(0, 0.5), clip = "off") +
    guides(color = "none", fill = "none", alpha = "none",
           x = guide_axis(angle = 90)) +
        labs(y = "Log Expression of a single gene") +
    theme(axis.title = element_blank(),
          axis.ticks.y.left = element_blank(),
          axis.text.y.left = element_blank())

# Cartoon bar heights of the double perturbation for the down-regulation
# example: the additive expectation, a synergistic value below it, a buffered
# value above it, and an "opposite" value slightly above control.
categories_df <- tibble(
  add = sum(height_spec),
  syn = add - 0.1,
  buf = add + 0.12,
  opp = height_spec["ctrl"] + 0.03
) |>
  pivot_longer(cols = everything(), names_to = "class", values_to = "height") |>
  mutate(class = factor(class, levels = c("add", "buf", "syn", "opp")))

# Stacked arrow segments (control, then A's effect, then B's effect), repeated
# for every interaction class. Rows are sorted with `Bd` first, then `Ad`,
# then `ctrl`.
arrow_df <- enframe(height_spec, name = "part", value = "height") |>
  mutate(
    start_pos = cumsum(lag(height, default = 0)),
    end_pos = cumsum(height),
    class = list(categories_df$class),
    part = factor(part, levels = c("ctrl", "Ad", "Bd"))
  ) |>
  unnest(class) |>
  arrange(desc(part))

# Additive expectation and control level of the down-regulation example.
sum_height_spec2 <- sum(height_spec)
rel_zero2 <- height_spec["ctrl"]
# Cartoon of the four interaction classes for the down-regulation example:
# grey bars give the double-perturbation level, and the two error bars next to
# the additive bar span +/- 0.02 around the additive expectation.
obs_plot2 <- categories_df |>
  ggplot(aes(x = class, y = height)) +
    geom_col(fill = "#5e5e5e", width = 0.5) +
    annotate(geom = "errorbar", x = 1.25, ymin = sum_height_spec2 - 0.02, ymax = sum_height_spec2 + 0.02, width = 0.1) +
    annotate(geom = "errorbar", x = 0.75, ymin = sum_height_spec2 - 0.02, ymax = sum_height_spec2 + 0.02, width = 0.1) +
    # Only the two perturbation-effect arrows are drawn; the control arrow is
    # dropped. (A former `arrange(1)` step sorted by a constant and therefore
    # never changed the row order, so it was removed.)
    geom_segment(data = arrow_df |> filter(part != "ctrl"),
                 aes(color = part, y = start_pos, yend = end_pos), linewidth = 0.7,
                 arrow = grid::arrow(type = "closed", length = unit(1, "mm"))) +
    # The secondary axis shifts the scale so that the control level reads 0.
    scale_y_continuous(expand = expansion(add = 0),
                       sec.axis = sec_axis(transform = \(x) x - rel_zero2,
                                           name = "LFC over additive expectation", breaks = c(0))) +
    scale_x_discrete(labels = c(add = "Additive\n(Non-interaction)", buf = "Buffering", syn = "Synergy", opp = "Opposite")) +
    scale_fill_manual(values = cartoon_color_values) +
    scale_color_manual(values = cartoon_color_values) +
    coord_cartesian(ylim = c(0, 0.5), clip = "off") +
    guides(color = "none", fill = "none", alpha = "none",
           x = guide_axis(angle = 90)) +
    theme(axis.title = element_blank(),
          axis.ticks.y.left = element_blank(),
          axis.text.y.left = element_blank(),
          axis.line.y.left = element_blank())

# Preview the two lower cartoon panels side by side.
cowplot::plot_grid(train_plot2, obs_plot2, align = "h")

# Combine the two cartoon examples into one panel: the up-regulation row on top
# (x axis text removed) and the down-regulation row below. The lower row is
# given 1.5x the height of the upper row.
cartoon_pl <- cowplot::plot_grid(
  cowplot::plot_grid(train_plot + theme(axis.text.x = element_blank()), obs_plot + theme(axis.text.x = element_blank()), rel_widths = c(3, 4), align = "h"),
  cowplot::plot_grid(train_plot2, obs_plot2, rel_widths = c(3, 4), align = "h"),
  ncol = 1, rel_heights = c(1, 1.5), align = "v"
)
cartoon_pl

# Assemble the main perturbation-prediction figure (panels A-F) on a
# 180 x 193 mm canvas and write it to ../plots/perturbation_prediction.pdf.
# Elements are listed in the order they are passed to plot_assemble().
plot_assemble(
  add_text("(A) Double perturbation prediction error", x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(main_pl_double_l2, x = 0, y = 4, width = 125, height = 60),

  add_text("(B) Example:\\;{\\scriptsize\\textcolor{baseROrange}\\faCircle}\\;CEBPE+CEBPB", x = 128, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(obs_pred_corr_pl, x = 125, y = 4, width = 58, height = 60),

  add_text("(C) Accuracy of interaction predictions", x = 2.7, y = 66, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(tp_fdp_pl, x = 0, y = 70, width = 60, height = 58),

  add_text("(D) Classification of interactions", x = 62.7, y = 66, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(cartoon_pl, x = 65, y = 70, width = 65, height = 60),
  # Shared y-axis captions for the cartoon panel (its own axis titles are
  # blanked in the individual plots).
  add_text("Log expression of example read-out gene", x = 64.5, y = 93, angle = 90, fontsize = font_size_small, vjust = 0.5, hjust = 0.5),
  add_text("LFC compared to control", x = 130, y = 93, angle = -90, fontsize = font_size_small, vjust = 0.5, hjust = 0.5),

  add_text("(E) Observed composition\nof interaction classes", x = 141, y = 66, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(non_add_pl1, x = 139, y = 76, width = 20, height = 45),
  # Translucent polygon placed between the overview bar and the zoomed bar.
  # `height` is spelled out in full: the previous `heigh = 45` only worked
  # through partial argument matching.
  add_plot(grid::polygonGrob(x = c(0.398, 0.65, 0.65, 0.398), y = c(0.969, 0.969, 0.05, 0.927),
                            gp = grid::gpar(fill = non_additive_colors["Non-additive"], alpha = 0.2, lty = 0)),
          x = 140, y = 76, width = 37, height = 45),
  add_plot(non_add_pl2, x = 156, y = 76, width = 20, height = 45),

  add_text("(F) Prediction of change relative to additive expectation and interaction class", x = 2.7, y = 131.5, fontsize = font_size, vjust = 1, fontface = "bold"),
  # The scatter plot is drawn without its legend; the legend is extracted and
  # placed separately next to the panel title.
  add_plot(pert_pred_comparison + guides(color = "none"), x = 4, y = 135, width = 176, height = 40),
  add_plot(mosaic_plot + theme(strip.text = element_blank(), axis.text.x = element_text(size = font_size_tiny)), x = 0, y = 173, width = 180, height = 20),
  add_plot(my_get_legend(pert_pred_comparison), x = 125, y = 131.5, width = 50, height = 5),

  width = 180, height = 193, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/perturbation_prediction.pdf"
)
## gg[gg1]
## Warning: Removed 147 rows containing missing values or values outside the scale
## range (`position_quasirandom()`).
## gg[gg2]
## gg[gg3]
## gg[gg4]
## gg[gg5]
## gg[gg6]
## gg[gg7]
## gg[gg8]
## gg[gg9]
## gg[gg10]
## gg[gg11]
## gg[gg12]
## gg[gg13]
## gg[gg14]
## gg[gg15]
## gg[gg16]
## Warning: Removed 40 rows containing missing values or values outside the scale
## range (`geom_col()`).
## gg[gg17]
## gg[gg18]
##  [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## [16] TRUE TRUE TRUE
# Export the data behind the main figure, one worksheet per panel.
writexl::write_xlsx(list(
  "Panel A" = main_pl_data,
  "Panel B" = obs_pred_corr_dat |> dplyr::select(-c(prediction_std, perturbation_split)),
  "Panel C" = tp_fdp_data,
  "Panel E" = non_additive_counts,
  # Panel F: only the 500 predictions with the largest absolute deviation per
  # method are exported (the same selection that is highlighted in the plot).
  "Panel F" = pert_pred_comparison_df %>%
    filter(rank(desc(abs(pred_minus_add))) <= 500, .by = c(method)) |>
    mutate(`NOTE THIS IS ONLY THE BOLD DATA POINTS` = "")
), path = "../source_data/perturbation_prediction.xlsx")
# For each (perturbation, gene, seed) observation that appears among a method's
# first 500 rows, collect the labels of all methods that include it.
# NOTE(review): slice_head() relies on the row order of
# `tp_fdp_prec_recall_data_pre`, which was last sorted by `fdp`; elsewhere the
# top predictions are selected with slice_max(abs(pred_minus_add)). Confirm
# which ordering is intended here.
overlap_preparation_df <- tp_fdp_prec_recall_data_pre |>
  slice_head(n = 500, by = c(method)) |>
  summarize(.by = c(perturbation, gene_name, seed, true_nonadditive, interaction_label),
            # Look labels up by name; indexing with the factor itself would
            # use its integer codes and silently depend on the level order.
            methods = list(method_labels[as.character(method)])) |>
  # Flag observations shared by at least three methods, none of which is the
  # "No Change" baseline.
  mutate(shared_deep_learning_exclusives = lengths(methods) >= 3 & map_lgl(methods, \(x) ! "No Change" %in% x))

# UpSet plot of how often each combination of methods shares an observation,
# limited to 25 intersections; bars are coloured by the
# `shared_deep_learning_exclusives` flag.
upset_plot_pred_overlap <- overlap_preparation_df |>
  ggplot(aes(x = methods)) +
    geom_bar(aes(fill = shared_deep_learning_exclusives), show.legend = FALSE) +
    scale_y_continuous(expand = expansion(add = 0)) +
    ggupset::scale_x_upset(sets = unname(method_labels), n_intersections = 25) +
    theme(axis.title.x = element_blank()) +
    ggupset::theme_combmatrix(combmatrix.panel.point.size = 1,
                              combmatrix.panel.line.size = 0.8)

upset_plot_pred_overlap
## Warning: Removed 265 rows containing non-finite outside the scale range
## (`stat_count()`).

# Ten most frequent genes among the calls flagged as
# `shared_deep_learning_exclusives`; `with_ties = FALSE` caps the table at
# ten rows.
tab1 <- overlap_preparation_df |>
  filter(shared_deep_learning_exclusives) |>
  count(gene_name) |>
  slice_max(order_by = n, n = 10, with_ties = FALSE) |>
  dplyr::rename("Gene Name" = gene_name)

# Same tally, but per perturbation instead of per gene.
tab2 <- overlap_preparation_df |>
  filter(shared_deep_learning_exclusives) |>
  count(perturbation) |>
  slice_max(order_by = n, n = 10, with_ties = FALSE) |>
  dplyr::rename("Perturbation" = perturbation)

# Render the two top-10 tables side by side as a LaTeX snippet that can be
# placed as text into the assembled figure further down.
shared_deep_learning_exclusives_tab <- cbind(tab1, tab2) |>
  tinytable::tt() |>
  tinytable::style_tt(fontsize = 0.8) |>
  # Line on the right-hand side ("r") of column 2, i.e. between the two tables
  tinytable::style_tt(j = 2, line = "r") |>
  # NOTE(review): `build_tt` is accessed via `:::`, so it is an internal
  # tinytable function and may change between package versions.
  tinytable:::build_tt(output = "latex") %>%
  {
    # Post-process the generated LaTeX stored in the `table_string` slot
    str <- (.@table_string) |>
      # Drop everything from a `%%` up to and including the end of that line
      str_remove_all("%%.*\n") |>
      # Strip the first `\begin{table}` / `\end{table}` occurrence
      str_remove(r"(\\begin\{table\})") |>
      str_remove(r"(\\end\{table\})") 
    # Wrap the remaining table code in a full-width minipage
    paste0("\\begin{minipage}{\\textwidth}", str, "\\end{minipage}")
  }

# Print the final LaTeX string for inspection
cat(shared_deep_learning_exclusives_tab)
## \begin{minipage}{\textwidth}
## \centering
## \begin{tblr}[         ]                     {                     colspec={Q[]Q[]Q[]Q[]},
## column{1,2,3,4}={}{font=\fontsize{0.8em}{1.1em}\selectfont,},
## vline{3}={1,2,3,4,5,6,7,8,9,10,11}{solid, black, 0.1em},
## }                     \toprule
## Gene Name & n & Perturbation & n \\ \midrule HBZ & 59 & CEBPE+CEBPA & 36 \\
## HBG2 & 45 & CEBPB+CEBPA & 34 \\
## GYPB & 15 & CEBPB+MAPK1 & 16 \\
## SH3BGRL3 & 13 & CEBPE+CEBPB & 8 \\
## TMSB10 & 13 & JUN+CEBPA & 7 \\
## GYPA & 12 & ZC3HAV1+CEBPE & 7 \\
## HBG1 & 12 & AHR+FEV & 6 \\
## VIM & 8 & ETS2+MAPK1 & 6 \\
## CYBA & 5 & CDKN1C+CDKN1B & 5 \\
## RANBP1 & 5 & CEBPE+RUNX1T1 & 5 \\
## \bottomrule
## \end{tblr}
## \end{minipage}
# Assemble the figure on deep-learning-specific interaction calls and write it
# to the PDF given by `filename`.
# NOTE(review): `plot_assemble()`, `add_text()` and `add_plot()` are not defined
# in this section -- presumably helpers from util.R; positions and sizes
# presumably follow `units = "mm"`.
plot_assemble(
  # Panel A: UpSet plot with an explanatory subtitle (contains LaTeX markup)
  add_text("(A) Overlap of interaction predictions", x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(upset_plot_pred_overlap + labs(subtitle="Set of predictions where at least 3 DL tools predict an interaction and \\emph{no change} does not (blue bars)"),
           x = 0, y = 4.5, width = 110, height = 60),
  
  # Panel B: the LaTeX table built above; its newlines are replaced by spaces
  # before it is handed to add_text()
  add_text("(B) Deep Learning interactions not\n\\;\\;found by \\emph{no change}", x = 120, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_text(str_replace_all(shared_deep_learning_exclusives_tab, "\n", " "), x = 100, y = 34, fontsize = font_size, vjust = 1, hjust = 0),


  # Overall canvas size and output settings
  width = 180, height = 65, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/response-deep_learning_specific_calls.pdf"
)
## gg[gg1]
## Warning: Removed 265 rows containing non-finite outside the scale range
## (`stat_count()`).
## gg[gg2]
## gg[gg3]
## gg[gg4]

Look at reoccurring genes

# Preview the label columns that get joined onto the predictions below
# (`label` is exposed under the name `interaction_label`).
dplyr::select(
  gene_response_label_df,
  gene_name, pert_group, seed, pert_same_dir,
  interaction_label = label
)
## # A tibble: 620,000 × 5
##    gene_name pert_group  seed pert_same_dir interaction_label
##    <chr>     <chr>      <int> <lgl>         <fct>            
##  1 ABCF1     AHR+FEV        1 TRUE          Additive         
##  2 ABRACL    AHR+FEV        1 FALSE         Additive         
##  3 ABT1      AHR+FEV        1 FALSE         Additive         
##  4 ACP1      AHR+FEV        1 FALSE         Additive         
##  5 ACTB      AHR+FEV        1 FALSE         Additive         
##  6 ACTG1     AHR+FEV        1 TRUE          Additive         
##  7 ACTN4     AHR+FEV        1 TRUE          Additive         
##  8 ADRM1     AHR+FEV        1 FALSE         Additive         
##  9 AK2       AHR+FEV        1 TRUE          Additive         
## 10 ALDOA     AHR+FEV        1 FALSE         Additive         
## # ℹ 619,990 more rows
# Count, per seed and method, how often each gene occurs among the 100 rows
# with the largest |pred_minus_add|. Only the four overall most frequent genes
# keep their own level; all others are lumped into "Other".
top_gene_reoccurence <- inter_pred_dat %>%
  tidylog::inner_join(
    dplyr::select(gene_response_label_df, gene_name, pert_group, seed, pert_same_dir, interaction_label = label),
    by = c("perturbation" = "pert_group", "gene_name", "seed")
  ) %>%
  tidylog::filter(pert_same_dir) %>%
  (\(pred) {
    # Append the observations as a pseudo-method "ground_truth": take the
    # "no_change" rows and use `obs_minus_add` in place of the prediction.
    observed <- pred %>%
      filter(method == "no_change") %>%
      transmute(method = "ground_truth", seed, perturbation, gene_name, pred_minus_add = obs_minus_add)
    bind_rows(pred, observed)
  }) %>%
  dplyr::select(seed, method, perturbation, gene_name, pred_minus_add) %>%
  group_by(method, seed) %>%
  slice_max(abs(pred_minus_add), n = 100, with_ties = FALSE) %>%
  ungroup() %>%
  mutate(
    # Order levels by frequency, then keep only the first four
    gene_name = fct_infreq(gene_name),
    gene_name = fct_other(gene_name, keep = levels(gene_name)[1:4])
  ) %>%
  count(seed, method, gene_name)
## inner_join: added 2 columns (pert_same_dir, interaction_label)
##             > rows only in x                         (        0)
##             > rows only in dplyr::select(gene_resp.. (  310,000)
##             > matched rows                            2,790,000
##             >                                        ===========
##             > rows total                              2,790,000
## filter: removed 1,056,897 rows (38%), 1,733,103 rows remaining
# Count, per seed and method, how often each perturbation occurs among the 100
# rows with the largest |pred_minus_add|. Only the four overall most frequent
# perturbations keep their own level; all others are lumped into "Other".
top_perturbation_reoccurence <- inter_pred_dat %>%
  tidylog::inner_join(
    dplyr::select(gene_response_label_df, gene_name, pert_group, seed, pert_same_dir, interaction_label = label),
    by = c("perturbation" = "pert_group", "gene_name", "seed")
  ) %>%
  tidylog::filter(pert_same_dir) %>%
  (\(pred) {
    # Append the observations as a pseudo-method "ground_truth": take the
    # "no_change" rows and use `obs_minus_add` in place of the prediction.
    observed <- pred %>%
      filter(method == "no_change") %>%
      transmute(method = "ground_truth", seed, perturbation, gene_name, pred_minus_add = obs_minus_add)
    bind_rows(pred, observed)
  }) %>%
  dplyr::select(seed, method, perturbation, gene_name, pred_minus_add) %>%
  group_by(method, seed) %>%
  slice_max(abs(pred_minus_add), n = 100, with_ties = FALSE) %>%
  ungroup() %>%
  mutate(
    # Order levels by frequency, then keep only the first four
    perturbation = fct_infreq(perturbation),
    perturbation = fct_other(perturbation, keep = levels(perturbation)[1:4])
  ) %>%
  count(seed, method, perturbation)
## inner_join: added 2 columns (pert_same_dir, interaction_label)
##             > rows only in x                         (        0)
##             > rows only in dplyr::select(gene_resp.. (  310,000)
##             > matched rows                            2,790,000
##             >                                        ===========
##             > rows total                              2,790,000
## filter: removed 1,056,897 rows (38%), 1,733,103 rows remaining
# Fill colours for the four most frequent levels; "grey" is appended inside the
# plotting helper for the lumped "Other" level.
# NOTE: despite its name, this vector holds four colours, not six.
ggplot_colors_six <- colorspace::qualitative_hcl(4, h = c(0, 270), c = 60, l = 70)

# Stacked bar chart of reoccurrence counts per method, one facet per seed.
#
# counts:   table with columns `seed`, `method`, `n` and the column given as
#           `fill_var` (as produced by the two reoccurrence tallies above).
# fill_var: unquoted name of the column that defines the stacked segments.
#
# Rows of the "additive_model" method are dropped; "ground_truth" is placed
# first on the x axis, followed by the methods in the order of `method_labels`.
# Returns a ggplot object.
plot_reoccurrence_counts <- function(counts, fill_var) {
  counts |>
    filter(method != "additive_model") |>
    mutate(method = factor(method, levels = c("ground_truth", names(method_labels)))) |>
    ggplot(aes(x = method, y = n)) +
    geom_col(aes(fill = {{ fill_var }})) +
    scale_fill_manual(values = c(ggplot_colors_six, "grey")) +
    scale_x_discrete(labels = c("ground_truth" = "Ground Truth", method_labels)) +
    scale_y_continuous(expand = expansion(add = 0)) +
    facet_grid(vars(), vars(seed), labeller = label_both) +
    guides(x = guide_axis(angle = 90)) +
    labs(x = "", y = "No. occurrences", fill = "")
}

top_gene_reoccurence_plot <- plot_reoccurrence_counts(top_gene_reoccurence, gene_name)

top_perturbation_reoccurence_plot <- plot_reoccurrence_counts(top_perturbation_reoccurence, perturbation)


top_gene_reoccurence_plot

top_perturbation_reoccurence_plot

# For every seed x prediction method (ground truth and additive model
# excluded): number of top-100 calls that fall on one of the individually kept
# genes (i.e. not "Other"), summarised across all seed/method combinations.
top_gene_reoccurence |>
  filter(method != "ground_truth", method != "additive_model") |>
  summarize(n_named_genes = sum(n[gene_name != "Other"]), .by = c(seed, method)) |>
  pull(n_named_genes) |>
  summary()
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   16.00   39.00   53.00   47.88   58.25   64.00
# For the ground truth only: number of top-100 calls per seed that fall on one
# of the individually kept perturbations (i.e. not "Other").
# NOTE(review): the gene summary above *excludes* "ground_truth", whereas this
# one keeps *only* "ground_truth" -- confirm that the asymmetry is intended.
top_perturbation_reoccurence |>
  filter(method == "ground_truth") |>
  summarize(n_named_perturbations = sum(n[perturbation != "Other"]), .by = c(seed, method)) |>
  pull(n_named_perturbations) |>
  summary()
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##       7      55      55      48      59      64
# Assemble the supplementary figure with the two reoccurrence bar charts
# stacked vertically and write it to the PDF given by `filename`.
# NOTE(review): `plot_assemble()`, `add_text()` and `add_plot()` are not defined
# in this section -- presumably helpers from util.R.
plot_assemble(
  # Panel A: counts per gene
  add_text("(A) Reoccuring genes among top 100 interaction predictions", x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(top_gene_reoccurence_plot, x = 0, y = 4, width = 180, height = 65),
  
  # Panel B: counts per perturbation
  add_text("(B) Reoccuring perturbations among top 100 interaction predictions", x = 2.7, y = 70, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(top_perturbation_reoccurence_plot, x = 0, y = 73, width = 180, height = 65),

  # Overall canvas size and output settings
  width = 180, height = 140, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl-non_additive_gene_reoccurence.pdf"
)
## gg[gg1]
## gg[gg2]
## gg[gg3]
## gg[gg4]
# Export the count tables behind both panels, one worksheet per panel.
writexl::write_xlsx(
  x = list(
    "Panel A" = top_gene_reoccurence,
    "Panel B" = top_perturbation_reoccurence
  ),
  path = "../source_data/suppl-non_additive_gene_reoccurence.xlsx"
)
# Data for the HBG2 / HBZ detail plots: seed 1 only, restricted to rows with
# `pert_same_dir`, with a per-method rank of |pred_minus_add| that is computed
# across all genes *before* the table is narrowed down to the two genes.
plot_data <- inter_pred_dat %>%
  tidylog::inner_join(
    dplyr::select(gene_response_label_df, gene_name, pert_group, seed, pert_same_dir, interaction_label = label),
    by = c("perturbation" = "pert_group", "gene_name", "seed")
  ) %>%
  tidylog::filter(pert_same_dir) %>%
  filter(seed == 1) %>%
  mutate(rank = rank(-abs(pred_minus_add)), .by = method) %>%
  filter(gene_name %in% c("HBG2", "HBZ"), method != "additive_model") %>%
  mutate(gene_name = factor(gene_name, levels = c("HBG2", "HBZ")))
## inner_join: added 2 columns (pert_same_dir, interaction_label)
##             > rows only in x                         (        0)
##             > rows only in dplyr::select(gene_resp.. (  310,000)
##             > matched rows                            2,790,000
##             >                                        ===========
##             > rows total                              2,790,000
## filter: removed 1,056,897 rows (38%), 1,733,103 rows remaining
# Observed vs. predicted expression of a single gene across double
# perturbations, one facet per method.
#
# data:        table shaped like `plot_data`.
# target_gene: gene symbol to keep (compared against `gene_name`); it is also
#              inserted into the y-axis label as LaTeX `\emph{...}`.
#
# Layers, per perturbation (x axis ordered by `ref`):
#   * faint horizontal line at `baseline`
#   * short dark-grey horizontal segment at `ref`
#   * light-grey box spanning `ref - lower_thres` to `ref + lower_thres`
#   * large point at `ground_truth`, coloured by `interaction_label`
#   * small square at `value`
# Returns a ggplot object.
plot_top_gene_expression <- function(data, target_gene) {
  data |>
    # `.env$` makes sure the function argument is used, never a data column
    filter(gene_name == .env$target_gene) |>
    mutate(perturbation = fct_reorder(perturbation, ref)) |>
    # NOTE(review): `true_nonadditive` is computed but not mapped to any
    # aesthetic below (the alpha scale is therefore unused as well).
    mutate(true_nonadditive = obs_minus_add > upper_thres | obs_minus_add < lower_thres) |>
    ggplot(aes(x = perturbation)) +
    geom_hline(aes(yintercept = baseline), alpha = 0.1) +
    ungeviz::geom_hpline(aes(y = ref), color = "darkgrey", linewidth = 0.3, width = 0.8) +
    # NOTE(review): both box edges are derived from `lower_thres`; `upper_thres`
    # is not used here -- confirm that a symmetric band is intended.
    geom_rect(
      aes(
        xmin = stage(perturbation, after_scale = x - 0.5),
        xmax = stage(perturbation, after_scale = x + 0.5),
        ymin = ref - lower_thres,
        ymax = ref + lower_thres
      ),
      fill = "#e2e2e2", alpha = 0.3
    ) +
    geom_point(aes(y = ground_truth, color = interaction_label), size = 1.6, stroke = 0) +
    geom_point(aes(y = value), size = 0.6, stroke = 0, shape = "square") +
    scale_color_manual(values = non_additive_colors, drop = TRUE) +
    scale_alpha_manual(values = c("TRUE" = 1, "FALSE" = 0.2)) +
    facet_wrap(vars(method), nrow = 3, labeller = as_labeller(method_labels)) +
    guides(x = guide_axis(angle = 90), color = "none", alpha = "none") +
    labs(y = paste0("Expression of \\emph{", target_gene, "}"), x = "Double perturbation") +
    theme(axis.text.x = element_text(size = 4))
}

pl1 <- plot_top_gene_expression(plot_data, "HBG2")

pl2 <- plot_top_gene_expression(plot_data, "HBZ")
# Assemble the supplementary figure with the HBG2 and HBZ detail plots and
# write it to the PDF given by `filename`.
# NOTE(review): `plot_assemble()`, `add_text()`, `add_plot()` and
# `my_get_legend()` are not defined in this section -- presumably helpers from
# util.R.
plot_assemble(
  # Title line
  add_text("Analysis of the predicted and observed expression patterns for \\emph{HBG2} and \\emph{HBZ}", x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  # Explanatory caption; the LaTeX colour names (nonAdditivePurple, ...) and the
  # \faCircle / \faSquare symbols are presumably provided by the LaTeX preamble
  # that plot_assemble() uses -- TODO confirm
  add_text(paste0("Comparison of the observed expression {\\scriptsize\\textcolor{nonAdditivePurple}\\faCircle}\\,{\\scriptsize\\textcolor{nonAdditiveGrey}\\faCircle}\\,{\\scriptsize\\textcolor{nonAdditiveOrange}\\faCircle} ",
                  "against predicted value {\\tiny\\faSquare} for each double perturbation. The grey box in the background shows the additive range."),
           x = 2.7, y = 6, fontsize = font_size_small, vjust = 1),
  # The two gene plots stacked vertically
  add_plot(pl1, x = 0, y = 8, width = 180, height = 90),
  add_plot(pl2, x = 0, y = 100, width = 180, height = 90),
  # pl1/pl2 hide their colour legend; re-enable it on pl2 only to extract a
  # horizontal one-row legend that is placed on the figure separately
  add_plot(my_get_legend(pl2 + guides(color = guide_legend(title = "", direction = "horizontal", nrow = 1))),
           x = 140, y = 170, width = 20, height = 10),

  # Overall canvas size and output settings
  width = 180, height = 190, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl-top_gene_plots.pdf"
)
## Measuring dimensions of: \bfseries{}Analysis of the predicted and observed expression patterns for \emph{HBG2} and \emph{HBZ}
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpcEUJJt/tikzDevice5ac293a3c32' 'tikzStringWidthCalc.tex'
## gg[gg1]
## gg[gg2]
## Measuring dimensions of: Expression of \emph{HBG2}
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpcEUJJt/tikzDevice5ac24e45fb59' 'tikzStringWidthCalc.tex'
## gg[gg3]
## Measuring dimensions of: Expression of \emph{HBZ}
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpcEUJJt/tikzDevice5ac25478b834' 'tikzStringWidthCalc.tex'
## gg[gg4]
## gg[gg5]
# Export the table underlying the HBG2 / HBZ plots as source data
writexl::write_xlsx(plot_data, path = "../source_data/suppl-top_gene_plots.xlsx")
# Exploratory check for seed 2: take each method's 100 largest (signed)
# `pred_minus_add` values and list the genes occurring most often among them.
# The final `slice_max()` keeps ties, so a method can contribute more than
# three rows.
inter_pred_dat |>
  filter(seed == 2) |>
  group_by(method) |>
  slice_max(order_by = pred_minus_add, n = 100, with_ties = FALSE) |>
  count(gene_name) |>
  slice_max(order_by = n, n = 3)
## # A tibble: 126 × 3
## # Groups:   method [9]
##    method         gene_name     n
##    <fct>          <chr>     <int>
##  1 no_change      HBG2          9
##  2 no_change      GAL           6
##  3 no_change      HBZ           5
##  4 additive_model ABCF1         1
##  5 additive_model ABRACL        1
##  6 additive_model ABT1          1
##  7 additive_model ACP1          1
##  8 additive_model ACTB          1
##  9 additive_model ACTG1         1
## 10 additive_model ACTN4         1
## # ℹ 116 more rows
# Exploratory check on the observations: using the "no_change" rows, take the
# 100 largest (signed) `obs_minus_add` values per seed and list the
# perturbations occurring most often among them (ties are kept).
inter_pred_dat |>
  filter(method == "no_change") |>
  group_by(seed, method) |>
  slice_max(order_by = obs_minus_add, n = 100, with_ties = FALSE) |>
  count(perturbation) |>
  slice_max(order_by = n, n = 3)
## # A tibble: 17 × 4
## # Groups:   seed, method [5]
##     seed method    perturbation       n
##    <int> <fct>     <chr>          <int>
##  1     1 no_change CEBPB+CEBPA       52
##  2     1 no_change CEBPE+KLF1         9
##  3     1 no_change FEV+CBFA2T3        7
##  4     2 no_change CEBPB+CEBPA       56
##  5     2 no_change CEBPE+KLF1        12
##  6     2 no_change CEBPE+CEBPB        6
##  7     2 no_change PTPN12+UBASH3A     6
##  8     3 no_change CEBPE+CEBPA       52
##  9     3 no_change CEBPE+KLF1         9
## 10     3 no_change FEV+CBFA2T3        7
## 11     3 no_change ZC3HAV1+CEBPE      7
## 12     4 no_change CEBPE+KLF1        21
## 13     4 no_change ZC3HAV1+CEBPE     19
## 14     4 no_change CEBPE+PTPN12      12
## 15     5 no_change CEBPB+CEBPA       35
## 16     5 no_change CEBPE+CEBPA       28
## 17     5 no_change SET+CEBPE         14

Resource usage

# Load the per-job resource statistics (one row per job and metric).
resource_df <- readr::read_tsv(
  file = "../benchmark/output/single_perturbation_jobs_stats.tsv"
)
## Rows: 608 Columns: 7
## ── Column specification ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr (6): name, metric, gpu_logged, gpu_ask, node, gpu_available
## dbl (1): value
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Resource statistics for the "norman_from_scfoundation" dataset, limited to
# the methods listed in `method_labels`, with a readable GPU label.
norman_df <- resource_df |>
  # Job names have the form "<dataset>-<seed>-<method>"
  separate(name, into = c("dataset", "seed", "method"), sep = "-") |>
  filter(
    method %in% names(method_labels),
    dataset == "norman_from_scfoundation"
  ) |>
  mutate(
    # Strip the "gpu=" prefix first so that the labels below match on the
    # bare model name
    gpu_available = str_remove(gpu_available, "gpu="),
    gpu_label = case_when(
      is.na(gpu_available) ~ "No GPU",
      gpu_available == "A40" ~ "NVIDIA A40",
      gpu_available == "3090" ~ "NVIDIA RTX 3090",
      gpu_available == "H100" ~ "NVIDIA H100",
      gpu_available == "L40s" ~ "NVIDIA L40s",
      .default = "Other GPU"
    )
  )

# Sanity check: how the logged GPU, the available GPU and the label line up
count(norman_df, gpu_logged, gpu_available, gpu_label)
## # A tibble: 7 × 4
##   gpu_logged                   gpu_available gpu_label           n
##   <chr>                        <chr>         <chr>           <int>
## 1 GPU: NVIDIA A40              A40           NVIDIA A40          4
## 2 GPU: NVIDIA GeForce RTX 3090 3090          NVIDIA RTX 3090    20
## 3 GPU: NVIDIA H100 PCIe        H100          NVIDIA H100        20
## 4 <NA>                         3090          NVIDIA RTX 3090    56
## 5 <NA>                         A40           NVIDIA A40         36
## 6 <NA>                         L40s          NVIDIA L40s         4
## 7 <NA>                         <NA>          No GPU             20
# Peak memory per method, one point per job, coloured by GPU label.
# NOTE(review): the "max_mem_kbytes" metric is converted to bytes with a factor
# of 1000 -- confirm whether the logged unit is kB (1000) or KiB (1024).
mem_pl <- norman_df |>
  filter(metric == "max_mem_kbytes", method != "ground_truth") |>
  mutate(method = factor(method, levels = names(method_labels))) |>
  ggplot(aes(x = method, y = value * 1000)) +
  ggbeeswarm::geom_quasirandom(aes(color = gpu_label), width = 0.2, size = 0.4) +
  scale_x_discrete(labels = method_labels) +
  scale_y_continuous(labels = scales::label_bytes()) +
  guides(
    x = guide_axis(angle = 90),
    # Enlarge the legend keys; the plotted points themselves are tiny
    color = guide_legend(override.aes = list(size = 1))
  ) +
  labs(y = "Peak memory usage (RAM)", x = "", color = "GPU") +
  theme(panel.grid.major.y = element_line(color = "lightgrey", linewidth = 0.2))

# Elapsed time per method on a log scale, one point per job, coloured by GPU
# label. The axis starts at 60 (labelled "1 min"); breaks are given in minutes
# and multiplied by 60 to match the scale of `value`.
dur_pl <- norman_df |>
  filter(metric == "elapsed", method != "ground_truth") |>
  mutate(method = factor(method, levels = names(method_labels))) |>
  ggplot(aes(x = method, y = value)) +
  ggbeeswarm::geom_quasirandom(aes(color = gpu_label), width = 0.2, size = 0.4) +
  scale_x_discrete(labels = method_labels) +
  scale_y_log10(
    limits = c(60, NA),
    breaks = c(1, 10, 60, 360, 1440, 4320) * 60,
    labels = c("1 min", "10 min", "1 hour", "6 hours", "1 day", "3 days")
  ) +
  guides(x = guide_axis(angle = 90)) +
  labs(y = "Duration", x = "", color = "GPU") +
  theme(panel.grid.major.y = element_line(color = "lightgrey", linewidth = 0.2))

mem_pl

dur_pl

# Assemble the resource-usage figure (duration left, memory right) and write it
# to the PDF given by `filename`.
# NOTE(review): `plot_assemble()`, `add_text()` and `add_plot()` are not defined
# in this section -- presumably helpers from util.R.
plot_assemble(
  # Panel A: duration; its colour legend is suppressed because panel B shows
  # the same GPU legend
  add_text("(A)", x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(dur_pl + guides(color = "none"), x = 0, y = 4, width = 76, height = 47.5),
  # Panel B: peak memory, including the GPU legend
  add_text("(B)", x = 82, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(mem_pl, x = 82, y = 4, width = 98, height = 47.5),
  
  # Overall canvas size and output settings
  width = 180, height = 52, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl-resource_usage.pdf"
)
## gg[gg1]
## gg[gg2]
## gg[gg3]
## gg[gg4]
# Export the resource statistics as source data.
# NOTE(review): this writes the unfiltered `resource_df`, whereas the plots
# above are drawn from the filtered `norman_df` -- confirm this is intended.
writexl::write_xlsx(resource_df, path = "../source_data/suppl-resource_usage.xlsx")

Session Info

# Print the R version, platform and attached/loaded package versions
sessionInfo()
## R version 4.4.1 (2024-06-14)
## Platform: aarch64-apple-darwin20
## Running under: macOS Sonoma 14.6
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib 
## LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.0
## 
## locale:
## [1] en_GB.UTF-8/en_GB.UTF-8/en_GB.UTF-8/C/en_GB.UTF-8/en_GB.UTF-8
## 
## time zone: Europe/London
## tzcode source: internal
## 
## attached base packages:
## [1] stats     graphics  grDevices datasets  utils     methods   base     
## 
## other attached packages:
##  [1] glue_1.7.0      lubridate_1.9.3 forcats_1.0.0   stringr_1.5.1  
##  [5] dplyr_1.1.4     purrr_1.0.2     readr_2.1.5     tidyr_1.3.1    
##  [9] tibble_3.2.1    ggplot2_3.5.1   tidyverse_2.0.0
## 
## loaded via a namespace (and not attached):
##   [1] RColorBrewer_1.1-3          strawr_0.0.92              
##   [3] jsonlite_1.8.8              shape_1.4.6.1              
##   [5] magrittr_2.0.3              ggbeeswarm_0.7.2           
##   [7] farver_2.1.2                rmarkdown_2.27             
##   [9] GlobalOptions_0.1.2         fs_1.6.4                   
##  [11] BiocIO_1.14.0               zlibbioc_1.50.0            
##  [13] vctrs_0.6.5                 locfdr_1.1-8               
##  [15] memoise_2.0.1               Cairo_1.6-2                
##  [17] Rsamtools_2.20.0            RCurl_1.98-1.16            
##  [19] htmltools_0.5.8.1           S4Arrays_1.4.1             
##  [21] curl_5.2.1                  SparseArray_1.4.8          
##  [23] gridGraphics_0.5-1          sass_0.4.9                 
##  [25] bslib_0.8.0                 legendry_0.2.0             
##  [27] zoo_1.8-12                  cachem_1.1.0               
##  [29] GenomicAlignments_1.40.0    lifecycle_1.0.4            
##  [31] iterators_1.0.14            pkgconfig_2.0.3            
##  [33] Matrix_1.7-0                R6_2.5.1                   
##  [35] fastmap_1.2.0               santoku_1.0.0              
##  [37] GenomeInfoDbData_1.2.12     tikzDevice_0.12.6          
##  [39] MatrixGenerics_1.16.0       clue_0.3-65                
##  [41] digest_0.6.36               ggbezier_0.1.0             
##  [43] colorspace_2.1-1            S4Vectors_0.42.1           
##  [45] GenomicRanges_1.56.1        labeling_0.4.3             
##  [47] tinytable_0.7.0             fansi_1.0.6                
##  [49] timechange_0.3.0            httr_1.4.7                 
##  [51] polyclip_1.10-7             abind_1.4-5                
##  [53] compiler_4.4.1              bit64_4.0.5                
##  [55] withr_3.0.1                 doParallel_1.0.17          
##  [57] BiocParallel_1.38.0         ggupset_0.4.0              
##  [59] highr_0.11                  ggforce_0.4.2              
##  [61] MASS_7.3-60.2               DelayedArray_0.30.1        
##  [63] lemur_1.2.0                 rjson_0.2.21               
##  [65] wesanderson_0.3.7           tools_4.4.1                
##  [67] vipor_0.4.7                 filehash_2.4-6             
##  [69] beeswarm_0.4.0              glmGamPoi_1.16.0           
##  [71] restfulr_0.0.15             shadowtext_0.1.4           
##  [73] grid_4.4.1                  cluster_2.1.6              
##  [75] generics_0.1.3              gtable_0.3.5               
##  [77] strapgod_0.0.4.9000         tzdb_0.4.0                 
##  [79] ungeviz_0.1.0               data.table_1.15.4          
##  [81] hms_1.1.3                   utf8_1.2.4                 
##  [83] XVector_0.44.0              BiocGenerics_0.50.0        
##  [85] ggrepel_0.9.6               foreach_1.5.2              
##  [87] pillar_1.9.0                vroom_1.6.5                
##  [89] yulab.utils_0.1.5           splines_4.4.1              
##  [91] circlize_0.4.16             tweenr_2.0.3               
##  [93] lattice_0.22-6              bit_4.0.5                  
##  [95] renv_1.0.7                  rtracklayer_1.64.0         
##  [97] tidyselect_1.2.1            SingleCellExperiment_1.26.0
##  [99] ComplexHeatmap_2.20.0       Biostrings_2.72.1          
## [101] knitr_1.48                  IRanges_2.38.1             
## [103] SummarizedExperiment_1.34.0 stats4_4.4.1               
## [105] xfun_0.50.5                 Biobase_2.64.0             
## [107] matrixStats_1.3.0           stringi_1.8.4              
## [109] UCSC.utils_1.0.0            yaml_2.3.10                
## [111] evaluate_0.24.0             codetools_0.2-20           
## [113] BiocManager_1.30.23         ggplotify_0.1.2            
## [115] cli_3.6.3                   munsell_0.5.1              
## [117] jquerylib_0.1.4             Rcpp_1.0.13                
## [119] GenomeInfoDb_1.40.1         tidylog_1.1.0              
## [121] png_0.1-8                   XML_3.99-0.17              
## [123] ggrastr_1.0.2               parallel_4.4.1             
## [125] assertthat_0.2.1            ggh4x_0.2.8                
## [127] plyranges_1.24.0            bitops_1.0-8               
## [129] scales_1.3.0                plotgardener_1.10.2        
## [131] crayon_1.5.3                writexl_1.5.4              
## [133] clisymbols_1.2.0            GetoptLong_1.0.5           
## [135] rlang_1.1.4                 cowplot_1.1.3